/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18 
19 #include <functional>
20 #include <iterator>
21 #include <memory>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/types/optional.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/gtl/iterator_range.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 
39 namespace internal {
40 
41 // Internal representation of each node in a ShapeTree.
42 template <typename T>
43 struct ShapeTreeNode {
44   // Data corresponding to this node.
45   std::pair<ShapeIndex, T> data;
46 
47   bool is_leaf = true;
48 
ShapeTreeNodeShapeTreeNode49   explicit ShapeTreeNode(ShapeIndex index)
50       : ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNodeShapeTreeNode51   ShapeTreeNode(ShapeIndex index, T data)
52       : data(std::move(index), std::move(data)) {}
53 };
54 
55 // Internal representation of an index table entry.
56 struct IndexTableEntry {
57   // Index of the node in the ShapeTreeNode vector.
58   uint32 index;
59   // Index of the first child in a IndexTableEntry vector. In the index
60   // table all children entries for a given node will be placed next to each
61   // other. This allows us to use a single field to index them.
62   uint32 children_start;
63 #ifndef NDEBUG
64   // Number of children, used for bounds checking.
65   uint32 children_count;
66 #endif
67 };
68 
69 }  // namespace internal
70 
71 template <typename ContainerType, typename IteratorType, typename ValueType>
72 class ShapeTreeIterator;
73 
74 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
75 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
76 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
77 // type T.
78 //
79 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
80 // ShapeTree is an identically structured tree with data elements of type T at
81 // every node. I.e. the root is a tuple by definition, all interior nodes are
82 // also tuples, and all leaves are arrays.
83 //
84 // Like the Shape data structure, this is a tree and tuple elements cannot be
85 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
86 // object.
87 //
88 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
89 // it's helpful not to copy a Shape just to make a ShapeTree.  In these cases,
90 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor.  It's
91 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
92 // before its ShapeTree goes away.
93 template <typename T>
94 class ShapeTree {
95  public:
96   using Node = internal::ShapeTreeNode<T>;
97   using Index = internal::IndexTableEntry;
98 
99   // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()100   ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
101 
102   // Create ShapeTree with the given shape, and default-constructed T values for
103   // all nodes.
104   //
105   // The version that takes a pointer may be cheaper because it doesn't require
106   // any Shape copies, but then it's up to you to ensure that the pointer stays
107   // alive longer than this ShapeTree.
108   explicit ShapeTree(Shape shape);
109   explicit ShapeTree(const Shape* shape);
110   explicit ShapeTree(const std::shared_ptr<Shape>& shape);
111 
112   // Create ShapeTree with the given shape, and init_value for all nodes.
113   ShapeTree(Shape shape, const T& init_value);
114   ShapeTree(const Shape* shape, const T& init_value);
115   ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
116 
117   // Returns the data element associated with the array in the shape at the
118   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
119   const T& element(ShapeIndexView index) const;
120   T* mutable_element(ShapeIndexView index);
121 
122   // Return the shape represented with this ShapeTree.
shape()123   const Shape& shape() const { return *shape_; }
124 
125   // A ShapeTree object can own the underlying Shape pointer (via the
126   // shape_storage_ member), or can point to a Shape object owned by the caller.
127   // This API replaces the underlying Shape object to the one supplied by the
128   // caller, whom must ensure the object remain valid for the whole lifetime of
129   // this ShapeTree object, and also that the Shape is consistent with it.
replace_shape_ptr(const Shape * shape)130   void replace_shape_ptr(const Shape* shape) {
131     if (shape_storage_ != nullptr) {
132       DCHECK_EQ(*shape, *shape_storage_);
133       shape_storage_ = nullptr;
134     }
135     shape_ = shape;
136   }
137 
138   // Returns true if the node at the given index is a leaf node (an array
139   // shape).
IsLeaf(ShapeIndexView index)140   bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
141 
142   ShapeTree(const ShapeTree&) = default;
143   ShapeTree& operator=(const ShapeTree&) = default;
144   ShapeTree(ShapeTree&&) = default;
145   ShapeTree& operator=(ShapeTree&& other) = default;
146 
147   // iterator implements a bidirectional_iterator with
148   //  value_type = std::pair<ShapeIndex, T>.
149   //
150   // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
151   using iterator =
152       ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
153                         std::pair<ShapeIndex, T>>;
154   using const_iterator =
155       ShapeTreeIterator<const std::vector<Node>,
156                         typename std::vector<Node>::const_iterator,
157                         const std::pair<ShapeIndex, T>>;
158   using reverse_iterator = std::reverse_iterator<iterator>;
159   using const_reverse_iterator = std::reverse_iterator<const_iterator>;
160 
161   // begin/end for iterating over all nodes.
begin()162   iterator begin() {
163     return iterator(&nodes_, nodes_.begin(),
164                     /*iterate_leaves_only=*/false);
165   }
end()166   iterator end() {
167     return iterator(&nodes_, nodes_.end(),
168                     /*iterate_leaves_only=*/false);
169   }
begin()170   const_iterator begin() const {
171     return const_iterator(&nodes_, nodes_.begin(),
172                           /*iterate_leaves_only=*/false);
173   }
end()174   const_iterator end() const {
175     return const_iterator(&nodes_, nodes_.end(),
176                           /*iterate_leaves_only=*/false);
177   }
178 
179   // rbegin/rend for iterating over all nodes in reverse.
rbegin()180   reverse_iterator rbegin() { return reverse_iterator(end()); }
rend()181   reverse_iterator rend() { return reverse_iterator(begin()); }
rbegin()182   const_reverse_iterator rbegin() const {
183     return const_reverse_iterator(end());
184   }
rend()185   const_reverse_iterator rend() const {
186     return const_reverse_iterator(begin());
187   }
188 
189   // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
190   // children).
leaf_begin()191   iterator leaf_begin() {
192     return iterator(&nodes_, nodes_.begin(),
193                     /*iterate_leaves_only=*/true);
194   }
leaf_end()195   iterator leaf_end() {
196     return iterator(&nodes_, nodes_.end(),
197                     /*iterate_leaves_only=*/true);
198   }
leaf_begin()199   const_iterator leaf_begin() const {
200     return const_iterator(&nodes_, nodes_.begin(),
201                           /*iterate_leaves_only=*/true);
202   }
leaf_end()203   const_iterator leaf_end() const {
204     return const_iterator(&nodes_, nodes_.end(),
205                           /*iterate_leaves_only=*/true);
206   }
207   // range-based iterator for leaf_begin()/leaf_end().
leaves()208   tensorflow::gtl::iterator_range<iterator> leaves() {
209     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
210   }
leaves()211   tensorflow::gtl::iterator_range<const_iterator> leaves() const {
212     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
213   }
214 
leaf_rbegin()215   reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); }
leaf_rend()216   reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); }
leaf_rbegin()217   const_reverse_iterator leaf_rbegin() const {
218     return const_reverse_iterator(leaf_end());
219   }
leaf_rend()220   const_reverse_iterator leaf_rend() const {
221     return const_reverse_iterator(leaf_begin());
222   }
223 
224   // Returns an iterator pointing to the given ShapeIndex.
225   // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)226   iterator find(ShapeIndexView index) {
227     Node* element = Lookup(index);
228     auto element_iter = nodes_.begin() + (element - &nodes_[0]);
229     return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
230   }
find(ShapeIndexView index)231   const_iterator find(ShapeIndexView index) const {
232     Node* element = Lookup(index);
233     auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
234     return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
235   }
236 
237   // Returns the number of leaf nodes in the tree.
leaf_count()238   int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
239 
240   // Recursively traverses the shape and calls the given function at each
241   // element. The function has the following arguments:
242   //
243   //   Fn :    A callable of type void(const ShapeIndex& index, const T& data)
244   //           (or compatible).
245   //   index : the index of the element in the shape. See ShapeUtil::GetSubshape
246   //           for definition of index.
247   //   data : The data value at this element.
248   template <typename Fn>
249   void ForEachElement(const Fn& func) const;
250 
251   // Like ForEachElement, but the callable has type
252   //
253   //   void (const ShapeIndex& index, T* data).
254   //
255   template <typename Fn>
256   void ForEachMutableElement(const Fn& func);
257 
258   // Like ForEach(Mutable)Element, but the callable returns a Status instead of
259   // void.  The first non-OK return value is returned by the ForEach* function.
260   template <typename Fn>
261   Status ForEachElementWithStatus(const Fn& func) const;
262   template <typename Fn>
263   Status ForEachMutableElementWithStatus(const Fn& func);
264 
265   // Maps each element to generate a new tree with the same shape.
266   template <typename U>
Map(const std::function<U (const T &)> & func)267   ShapeTree<U> Map(const std::function<U(const T&)>& func) {
268     ShapeTree<U> result(shape_storage_);
269     ForEachElement([&](const ShapeIndex& index, const T& t) {
270       *result.mutable_element(index) = func(t);
271     });
272     return result;
273   }
274 
275   template <typename U>
Map(const std::function<U (T *)> & func)276   ShapeTree<U> Map(const std::function<U(T*)>& func) {
277     ShapeTree<U> result(shape_storage_);
278     ForEachMutableElement([&](const ShapeIndex& index, T* t) {
279       *result.mutable_element(index) = func(t);
280     });
281     return result;
282   }
283 
284   // Copy the subtree of values from 'other' rooted at ShapeIndex
285   // 'source_base_index' into the subtree of value in this ShapeTree rooted at
286   // 'target_base_index'.
287   //
288   // Precondition: The subshape of other.shape() at index source_base_index must
289   // be compatible with the subshape of shape() at index target_base_index.
290   void CopySubtreeFrom(const ShapeTree<T>& other,
291                        const ShapeIndex& source_base_index,
292                        const ShapeIndex& target_base_index);
293 
294   StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const;
295 
296   bool operator==(const ShapeTree<T>& other) const;
297   bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
298 
299  private:
300   // Initialize node->children based on 'shape'. All children are assigned the
301   // the given 'init_value'.
302   void InitChildren(const Shape& shape, const T& init_value, Node* node,
303                     Index* index);
304 
305   // Initialize node->children based on 'shape'. All children have
306   // default-constructed data values.
307   void InitChildren(const Shape& shape, Node* node, Index* index);
308 
309   // Returns the number of subshapes, including interior nodes, in shape.
310   int64 CountSubshapes(const Shape& shape);
311 
312   // Helpers for traversing the shape via ForEachElement. The helpers
313   // recursively traverse the subtree rooted at "index" (defined as in
314   // ShapeUtil::GetSubshape).
315   template <typename Fn>
316   static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
317   template <typename Fn>
318   static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
319 
320   // Return the tree node at the given index.
321   Node* Lookup(ShapeIndexView index);
322   const Node* Lookup(ShapeIndexView index) const;
323 
324   // The nodes in this shape tree.
325   std::vector<Node> nodes_;
326 
327   // Index table for node lookups.
328   std::vector<Index> index_table_;
329 
330   // If we own our Shape, this field contains it, and shape_ is a pointer into
331   // here.  Otherwise if we don't own our shape, this is nullptr.
332   std::shared_ptr<Shape> shape_storage_;
333 
334   // The XLA shape mirrored in this ShapeTree.  This is either
335   // shape_storage_.get() or the Shape pointer passed to our constructor.
336   const Shape* shape_;
337 };
338 
// Internal iterator that performs a pre-order walk over a ShapeTree's node
// vector. It is cheap to copy: a container pointer, a wrapped iterator and a
// flag. Dereferencing yields a ValueType& (a std::pair<ShapeIndex, T>, const
// for const_iterator), similar to std::map.
//
// When iterate_leaves_only is true, nodes whose is_leaf is false are skipped
// in both directions.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator {
 public:
  // Iterator traits, spelled out explicitly because std::iterator is
  // deprecated since C++17. They match what
  // std::iterator<std::bidirectional_iterator_tag, ValueType> provided.
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type =
      typename std::iterator_traits<IteratorType>::difference_type;
  using pointer = ValueType*;
  using reference = ValueType&;

  // `nodes` is the container being walked and `node` the starting position.
  // In leaves-only mode the iterator advances to the first leaf at or after
  // `node`.
  ShapeTreeIterator(ContainerType* nodes, IteratorType node,
                    bool iterate_leaves_only)
      : nodes_(nodes),
        node_(std::move(node)),
        iterate_leaves_only_(iterate_leaves_only) {
    SkipInteriorNodesForward();
  }

  ShapeTreeIterator& operator++() {
    ++node_;
    SkipInteriorNodesForward();
    return *this;
  }
  ShapeTreeIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  ShapeTreeIterator& operator--() {
    --node_;
    while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) {
      --node_;
    }
    return *this;
  }
  ShapeTreeIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  bool operator==(const ShapeTreeIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeIterator& other) const {
    return node_ != other.node_;
  }
  // Dereference must work on a const iterator; constness of the iterator does
  // not imply constness of the referenced element.
  ValueType& operator*() const { return node_->data; }
  ValueType* operator->() const { return &node_->data; }

 private:
  // In leaves-only mode, moves forward past interior nodes (stopping at end).
  void SkipInteriorNodesForward() {
    while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) {
      ++node_;
    }
  }

  ContainerType* nodes_;
  IteratorType node_;
  // True if we should not include interior nodes in our walk. Deliberately
  // non-const so the iterator stays copy-assignable, as bidirectional
  // iterators must be.
  bool iterate_leaves_only_;
};
397 
398 template <typename T>
CountSubshapes(const Shape & shape)399 int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
400   int64 current_count = 1;
401   if (shape.IsTuple()) {
402     int64 count = ShapeUtil::TupleElementCount(shape);
403     for (int i = 0; i < count; ++i) {
404       current_count += CountSubshapes(shape.tuple_shapes(i));
405     }
406   }
407   return current_count;
408 }
409 
410 template <typename T>
InitChildren(const Shape & shape,const T & init_value,Node * node,Index * index)411 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
412                                 Node* node, Index* index) {
413   if (shape.IsTuple()) {
414     const int64 size = ShapeUtil::TupleElementCount(shape);
415 #ifndef NDEBUG
416     index->children_count = size;
417 #endif
418     node->is_leaf = false;
419     ShapeIndex shape_index = node->data.first;
420     shape_index.push_back(0);
421 
422     // At the end of the index_table, reserve a continuous space to hold the
423     // children of current node. In order to enforce the invariant that all
424     // children of a given node are placed together, we need to do the
425     // reservation before we recurse into any of its children.
426     int64 children_start_position = index_table_.size();
427     index_table_.resize(index_table_.size() + size);
428 
429     for (int i = 0; i < size; ++i) {
430       shape_index[shape_index.size() - 1] = i;
431       index_table_[children_start_position + i].index = nodes_.size();
432       // The first child of the node in the index table is placed at the end of
433       // the table.
434       index_table_[children_start_position + i].children_start =
435           index_table_.size();
436       nodes_.emplace_back(shape_index, init_value);
437       InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
438                    &index_table_[children_start_position + i]);
439     }
440   } else {
441 #ifndef NDEBUG
442     index->children_count = 0;
443 #endif
444   }
445 }
446 
447 template <typename T>
InitChildren(const Shape & shape,Node * node,Index * index)448 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
449   if (shape.IsTuple()) {
450     const int64 size = ShapeUtil::TupleElementCount(shape);
451 #ifndef NDEBUG
452     index->children_count = size;
453 #endif
454     node->is_leaf = false;
455     ShapeIndex shape_index = node->data.first;
456     shape_index.push_back(0);
457 
458     // At the end of the index_table, reserve a continuous space to hold the
459     // children of current node. In order to enforce the invariant that all
460     // children of a given node are placed together, we need to do the
461     // reservation before we recurse into any of its children.
462     int64 children_start_position = index_table_.size();
463     index_table_.resize(index_table_.size() + size);
464 
465     for (int i = 0; i < size; ++i) {
466       shape_index[shape_index.size() - 1] = i;
467       index_table_[children_start_position + i].index = nodes_.size();
468       // The first child of the node in the index table is placed at the end of
469       // the table.
470       index_table_[children_start_position + i].children_start =
471           index_table_.size();
472       nodes_.emplace_back(shape_index);
473       InitChildren(shape.tuple_shapes(i), &nodes_.back(),
474                    &index_table_[children_start_position + i]);
475     }
476   } else {
477 #ifndef NDEBUG
478     index->children_count = 0;
479 #endif
480   }
481 }
482 
483 template <typename T>
ShapeTree(Shape shape)484 ShapeTree<T>::ShapeTree(Shape shape)
485     : shape_storage_(std::make_shared<Shape>(std::move(shape))),
486       shape_(shape_storage_.get()) {
487   const int64 count = CountSubshapes(*shape_);
488   nodes_.reserve(count);
489   nodes_.emplace_back(ShapeIndex{});
490 
491   index_table_.reserve(count);
492   index_table_.emplace_back(Index{0, 1});
493   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
494 }
495 
496 template <typename T>
ShapeTree(const Shape * shape)497 ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
498   const int64 count = CountSubshapes(*shape_);
499   nodes_.reserve(count);
500   nodes_.emplace_back(ShapeIndex{});
501 
502   index_table_.reserve(count);
503   index_table_.emplace_back(Index{0, 1});
504   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
505 }
506 
507 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape)508 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
509     : shape_storage_(shape), shape_(shape_storage_.get()) {
510   const int64 count = CountSubshapes(*shape_);
511   nodes_.reserve(count);
512   nodes_.emplace_back(ShapeIndex{});
513 
514   index_table_.reserve(count);
515   index_table_.emplace_back(Index{0, 1});
516   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
517 }
518 
519 template <typename T>
ShapeTree(Shape shape,const T & init_value)520 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
521     : shape_storage_(std::make_shared<Shape>(std::move(shape))),
522       shape_(shape_storage_.get()) {
523   const int64 count = CountSubshapes(*shape_);
524   nodes_.reserve(count);
525   nodes_.emplace_back(ShapeIndex{}, init_value);
526 
527   index_table_.reserve(count);
528   index_table_.emplace_back(Index{0, 1});
529   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
530 }
531 
532 template <typename T>
ShapeTree(const Shape * shape,const T & init_value)533 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
534     : shape_(shape) {
535   const int64 count = CountSubshapes(*shape_);
536   nodes_.reserve(count);
537   nodes_.emplace_back(ShapeIndex{}, init_value);
538 
539   index_table_.reserve(count);
540   index_table_.emplace_back(Index{0, 1});
541   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
542 }
543 
544 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape,const T & init_value)545 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
546                         const T& init_value)
547     : shape_storage_(shape), shape_(shape_storage_.get()) {
548   const int64 count = CountSubshapes(*shape_);
549   nodes_.reserve(count);
550   nodes_.emplace_back(ShapeIndex{}, init_value);
551 
552   index_table_.reserve(count);
553   index_table_.emplace_back(Index{0, 1});
554   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
555 }
556 
557 template <typename T>
element(ShapeIndexView index)558 const T& ShapeTree<T>::element(ShapeIndexView index) const {
559   return Lookup(index)->data.second;
560 }
561 
562 template <typename T>
mutable_element(ShapeIndexView index)563 T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
564   return &Lookup(index)->data.second;
565 }
566 
567 template <typename T>
Lookup(ShapeIndexView index)568 internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
569   Index* iter = &index_table_[0];
570   for (const int64 i : index) {
571     CHECK_GE(i, 0);
572 #ifndef NDEBUG
573     CHECK_LT(i, iter->children_count);
574 #endif
575     iter = &index_table_[iter->children_start + i];
576   }
577 
578   return &nodes_[iter->index];
579 }
580 
581 template <typename T>
Lookup(ShapeIndexView index)582 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
583     ShapeIndexView index) const {
584   return const_cast<ShapeTree*>(this)->Lookup(index);
585 }
586 
587 /* static */
588 template <typename T>
589 template <typename Fn>
ForEachHelper(const Fn & func,const std::vector<Node> & nodes)590 Status ShapeTree<T>::ForEachHelper(const Fn& func,
591                                    const std::vector<Node>& nodes) {
592   for (const auto& node : nodes) {
593     TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
594   }
595   return Status::OK();
596 }
597 
598 /* static */
599 template <typename T>
600 template <typename Fn>
ForEachMutableHelper(const Fn & func,std::vector<Node> * nodes)601 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
602                                           std::vector<Node>* nodes) {
603   for (auto& node : *nodes) {
604     TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
605   }
606   return Status::OK();
607 }
608 
609 template <typename T>
610 template <typename Fn>
ForEachElementWithStatus(const Fn & func)611 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
612   return ForEachHelper(func, nodes_);
613 }
614 
615 template <typename T>
616 template <typename Fn>
ForEachMutableElementWithStatus(const Fn & func)617 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
618   return ForEachMutableHelper(func, &nodes_);
619 }
620 
621 template <typename T>
622 template <typename Fn>
ForEachElement(const Fn & func)623 void ShapeTree<T>::ForEachElement(const Fn& func) const {
624   return ForEachHelper(
625              [&func](const ShapeIndex& index, const T& data) {
626                func(index, data);
627                return Status::OK();
628              },
629              nodes_)
630       .IgnoreError();
631 }
632 
633 template <typename T>
634 template <typename Fn>
ForEachMutableElement(const Fn & func)635 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
636   return ForEachMutableHelper(
637              [&func](const ShapeIndex& index, T* data) {
638                func(index, data);
639                return Status::OK();
640              },
641              &nodes_)
642       .IgnoreError();
643 }
644 
645 template <typename T>
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & source_base_index,const ShapeIndex & target_base_index)646 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
647                                    const ShapeIndex& source_base_index,
648                                    const ShapeIndex& target_base_index) {
649   CHECK(ShapeUtil::Compatible(
650       ShapeUtil::GetSubshape(shape(), target_base_index),
651       ShapeUtil::GetSubshape(other.shape(), source_base_index)));
652   ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
653                             const ShapeIndex& index, T* data) {
654     // Copy the data element only if index is in the
655     // subtree rooted at target_base_index.
656     for (int i = 0; i < target_base_index.size(); ++i) {
657       if (i >= index.size() || index[i] != target_base_index[i]) {
658         return;
659       }
660     }
661     // Construct source element index to copy from.
662     ShapeIndex source_index = source_base_index;
663     for (int i = target_base_index.size(); i < index.size(); ++i) {
664       source_index.push_back(index[i]);
665     }
666     *data = other.element(source_index);
667   });
668 }
669 
670 template <typename T>
SubShapeTree(const ShapeIndex & index)671 StatusOr<ShapeTree<T>> ShapeTree<T>::SubShapeTree(
672     const ShapeIndex& index) const {
673   TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
674                       ShapeUtil::TryGetSubshape(shape(), index));
675   ShapeTree<T> sub_shape_tree(*sub_shape);
676   sub_shape_tree.CopySubtreeFrom(*this, index, {});
677   return std::move(sub_shape_tree);
678 }
679 
680 template <typename T>
681 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
682   bool equal = true;
683   ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
684     if (data != other.element(index)) {
685       equal = false;
686     }
687   });
688   return equal;
689 }
690 
691 }  // namespace xla
692 
693 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
694