• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18 
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/types/optional.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/gtl/iterator_range.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 
39 namespace internal {
40 
41 // Internal representation of each node in a ShapeTree.
42 template <typename T>
43 struct ShapeTreeNode {
44   // Data corresponding to this node.
45   std::pair<ShapeIndex, T> data;
46 
47   bool is_leaf = true;
48 
ShapeTreeNodeShapeTreeNode49   explicit ShapeTreeNode(ShapeIndex index)
50       : ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNodeShapeTreeNode51   ShapeTreeNode(ShapeIndex index, T data)
52       : data(std::move(index), std::move(data)) {}
53 };
54 
55 // Internal representation of an index table entry.
56 struct IndexTableEntry {
57   // Index of the node in the ShapeTreeNode vector.
58   uint32 index;
59   // Index of the first child in a IndexTableEntry vector. In the index
60   // table all children entries for a given node will be placed next to each
61   // other. This allows us to use a single field to index them.
62   uint32 children_start;
63 #ifndef NDEBUG
64   // Number of children, used for bounds checking.
65   uint32 children_count;
66 #endif
67 };
68 
69 }  // namespace internal
70 
71 template <typename ContainerType, typename IteratorType, typename ValueType>
72 class ShapeTreeIterator;
73 
74 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
75 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
76 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
77 // type T.
78 //
79 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
80 // ShapeTree is an identically structured tree with data elements of type T at
81 // every node. I.e. the root is a tuple by definition, all interior nodes are
82 // also tuples, and all leaves are arrays.
83 //
84 // Like the Shape data structure, this is a tree and tuple elements cannot be
85 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
86 // object.
87 //
88 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
89 // it's helpful not to copy a Shape just to make a ShapeTree.  In these cases,
90 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor.  It's
91 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
92 // before its ShapeTree goes away.
93 template <typename T>
94 class ShapeTree {
95  public:
96   using Node = internal::ShapeTreeNode<T>;
97   using Index = internal::IndexTableEntry;
98 
99   // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()100   ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
101 
102   // Create ShapeTree with the given shape, and default-constructed T values for
103   // all nodes.
104   //
105   // The version that takes a pointer may be cheaper because it doesn't require
106   // any Shape copies, but then it's up to you to ensure that the pointer stays
107   // alive longer than this ShapeTree.
108   explicit ShapeTree(Shape shape);
109   explicit ShapeTree(const Shape* shape);
110   explicit ShapeTree(const std::shared_ptr<Shape>& shape);
111 
112   // Create ShapeTree with the given shape, and init_value for all nodes.
113   ShapeTree(Shape shape, const T& init_value);
114   ShapeTree(const Shape* shape, const T& init_value);
115   ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
116 
117   // Returns the data element associated with the array in the shape at the
118   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
119   const T& element(ShapeIndexView index) const;
120   T* mutable_element(ShapeIndexView index);
121 
122   // Return the shape represented with this ShapeTree.
shape()123   const Shape& shape() const { return *shape_; }
124 
125   // Replaces *only* the underlying shape of this ShapeTree. The caller must own
126   // the Shape object and hence shape_storage_ is not updated.
127   //
128   // Only safe to use this if the ShapeTree was constructed with 'explicit
129   // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The
130   // caller must ensure that the input shape is consistent with the underlying
131   // tree.
replace_shape_ptr(const Shape * shape)132   void replace_shape_ptr(const Shape* shape) {
133     CHECK(shape_storage_.get() == nullptr);
134     shape_ = shape;
135   }
136 
137   // Returns true if the node at the given index is a leaf node (an array
138   // shape).
IsLeaf(ShapeIndexView index)139   bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
140 
141   ShapeTree(const ShapeTree&) = default;
142   ShapeTree& operator=(const ShapeTree&) = default;
143   ShapeTree(ShapeTree&&) = default;
144   ShapeTree& operator=(ShapeTree&& other) = default;
145 
146   // iterator implements a bidirectional_iterator with
147   //  value_type = std::pair<ShapeIndex, T>.
148   //
149   // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
150   using iterator =
151       ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
152                         std::pair<ShapeIndex, T>>;
153   using const_iterator =
154       ShapeTreeIterator<const std::vector<Node>,
155                         typename std::vector<Node>::const_iterator,
156                         const std::pair<ShapeIndex, T>>;
157   using reverse_iterator = std::reverse_iterator<iterator>;
158   using const_reverse_iterator = std::reverse_iterator<const_iterator>;
159 
160   // begin/end for iterating over all nodes.
begin()161   iterator begin() {
162     return iterator(&nodes_, nodes_.begin(),
163                     /*iterate_leaves_only=*/false);
164   }
end()165   iterator end() {
166     return iterator(&nodes_, nodes_.end(),
167                     /*iterate_leaves_only=*/false);
168   }
begin()169   const_iterator begin() const {
170     return const_iterator(&nodes_, nodes_.begin(),
171                           /*iterate_leaves_only=*/false);
172   }
end()173   const_iterator end() const {
174     return const_iterator(&nodes_, nodes_.end(),
175                           /*iterate_leaves_only=*/false);
176   }
177 
178   // rbegin/rend for iterating over all nodes in reverse.
rbegin()179   reverse_iterator rbegin() { return reverse_iterator(end()); }
rend()180   reverse_iterator rend() { return reverse_iterator(begin()); }
rbegin()181   const_reverse_iterator rbegin() const {
182     return const_reverse_iterator(end());
183   }
rend()184   const_reverse_iterator rend() const {
185     return const_reverse_iterator(begin());
186   }
187 
188   // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
189   // children).
leaf_begin()190   iterator leaf_begin() {
191     return iterator(&nodes_, nodes_.begin(),
192                     /*iterate_leaves_only=*/true);
193   }
leaf_end()194   iterator leaf_end() {
195     return iterator(&nodes_, nodes_.end(),
196                     /*iterate_leaves_only=*/true);
197   }
leaf_begin()198   const_iterator leaf_begin() const {
199     return const_iterator(&nodes_, nodes_.begin(),
200                           /*iterate_leaves_only=*/true);
201   }
leaf_end()202   const_iterator leaf_end() const {
203     return const_iterator(&nodes_, nodes_.end(),
204                           /*iterate_leaves_only=*/true);
205   }
206   // range-based iterator for leaf_begin()/leaf_end().
leaves()207   tensorflow::gtl::iterator_range<iterator> leaves() {
208     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
209   }
leaves()210   tensorflow::gtl::iterator_range<const_iterator> leaves() const {
211     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
212   }
213 
leaf_rbegin()214   reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); }
leaf_rend()215   reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); }
leaf_rbegin()216   const_reverse_iterator leaf_rbegin() const {
217     return const_reverse_iterator(leaf_end());
218   }
leaf_rend()219   const_reverse_iterator leaf_rend() const {
220     return const_reverse_iterator(leaf_begin());
221   }
222 
223   // Returns an iterator pointing to the given ShapeIndex.
224   // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)225   iterator find(ShapeIndexView index) {
226     Node* element = Lookup(index);
227     auto element_iter = nodes_.begin() + (element - &nodes_[0]);
228     return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
229   }
find(ShapeIndexView index)230   const_iterator find(ShapeIndexView index) const {
231     Node* element = Lookup(index);
232     auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
233     return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
234   }
235 
236   // Returns the number of leaf nodes in the tree.
leaf_count()237   int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
238 
239   // Recursively traverses the shape and calls the given function at each
240   // element. The function has the following arguments:
241   //
242   //   Fn :    A callable of type void(const ShapeIndex& index, const T& data)
243   //           (or compatible).
244   //   index : the index of the element in the shape. See ShapeUtil::GetSubshape
245   //           for definition of index.
246   //   data : The data value at this element.
247   template <typename Fn>
248   void ForEachElement(const Fn& func) const;
249 
250   // Like ForEachElement, but the callable has type
251   //
252   //   void (const ShapeIndex& index, T* data).
253   //
254   template <typename Fn>
255   void ForEachMutableElement(const Fn& func);
256 
257   // Like ForEach(Mutable)Element, but the callable returns a Status instead of
258   // void.  The first non-OK return value is returned by the ForEach* function.
259   template <typename Fn>
260   Status ForEachElementWithStatus(const Fn& func) const;
261   template <typename Fn>
262   Status ForEachMutableElementWithStatus(const Fn& func);
263 
264   // Maps each element to generate a new tree with the same shape.
265   template <typename U>
Map(const std::function<U (const T &)> & func)266   ShapeTree<U> Map(const std::function<U(const T&)>& func) {
267     ShapeTree<U> result(shape_storage_);
268     ForEachElement([&](const ShapeIndex& index, const T& t) {
269       *result.mutable_element(index) = func(t);
270     });
271     return result;
272   }
273 
274   template <typename U>
Map(const std::function<U (T *)> & func)275   ShapeTree<U> Map(const std::function<U(T*)>& func) {
276     ShapeTree<U> result(shape_storage_);
277     ForEachMutableElement([&](const ShapeIndex& index, T* t) {
278       *result.mutable_element(index) = func(t);
279     });
280     return result;
281   }
282 
283   // Copy the subtree of values from 'other' rooted at ShapeIndex
284   // 'source_base_index' into the subtree of value in this ShapeTree rooted at
285   // 'target_base_index'.
286   //
287   // Precondition: The subshape of other.shape() at index source_base_index must
288   // be compatible with the subshape of shape() at index target_base_index.
289   void CopySubtreeFrom(const ShapeTree<T>& other,
290                        const ShapeIndex& source_base_index,
291                        const ShapeIndex& target_base_index);
292 
293   bool operator==(const ShapeTree<T>& other) const;
294   bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
295 
296  private:
297   // Initialize node->children based on 'shape'. All children are assigned the
298   // the given 'init_value'.
299   void InitChildren(const Shape& shape, const T& init_value, Node* node,
300                     Index* index);
301 
302   // Initialize node->children based on 'shape'. All children have
303   // default-constructed data values.
304   void InitChildren(const Shape& shape, Node* node, Index* index);
305 
306   // Returns the number of subshapes, including interior nodes, in shape.
307   int64 CountSubshapes(const Shape& shape);
308 
309   // Helpers for traversing the shape via ForEachElement. The helpers
310   // recursively traverse the subtree rooted at "index" (defined as in
311   // ShapeUtil::GetSubshape).
312   template <typename Fn>
313   static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
314   template <typename Fn>
315   static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
316 
317   // Return the tree node at the given index.
318   Node* Lookup(ShapeIndexView index);
319   const Node* Lookup(ShapeIndexView index) const;
320 
321   // The nodes in this shape tree.
322   std::vector<Node> nodes_;
323 
324   // Index table for node lookups.
325   std::vector<Index> index_table_;
326 
327   // If we own our Shape, this field contains it, and shape_ is a pointer into
328   // here.  Otherwise if we don't own our shape, this is nullptr.
329   std::shared_ptr<Shape> shape_storage_;
330 
331   // The XLA shape mirrored in this ShapeTree.  This is either
332   // shape_storage_.get() or the Shape pointer passed to our constructor.
333   const Shape* shape_;
334 };
335 
// Internal iterator that performs a pre-order walk. This is cheap to copy.
// The iterator value_type is equivalent to a
// std::pair<ShapeIndex,T>&, similar to std::map.
//
// Template parameters:
//   ContainerType : the (possibly const) container of nodes being walked.
//   IteratorType  : ContainerType's matching iterator type. It must support
//                   operator> (used when walking backwards).
//   ValueType     : the (possibly const) type of each node's 'data' member.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator {
 public:
  // Member types required by std::iterator_traits. These are the same types
  // the deprecated std::iterator<std::bidirectional_iterator_tag, ValueType>
  // base class used to supply.
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type = std::ptrdiff_t;
  using pointer = ValueType*;
  using reference = ValueType&;

  // Creates an iterator over 'nodes' positioned at 'node'. When
  // 'iterate_leaves_only' is true the iterator is immediately advanced to the
  // first leaf at or after 'node' (or to the end if there is none).
  ShapeTreeIterator(ContainerType* nodes, IteratorType node,
                    bool iterate_leaves_only)
      : nodes_(nodes),
        node_(std::move(node)),
        iterate_leaves_only_(iterate_leaves_only) {
    SkipInteriorNodesForward();
  }

  ShapeTreeIterator& operator++() {
    ++node_;
    SkipInteriorNodesForward();
    return *this;
  }
  ShapeTreeIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  ShapeTreeIterator& operator--() {
    --node_;
    // In leaves-only mode keep walking backwards over interior nodes, but
    // never step before the first node.
    while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) {
      --node_;
    }
    return *this;
  }
  ShapeTreeIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  // Two iterators are equal when they refer to the same position; the
  // leaves-only flag does not take part in the comparison.
  bool operator==(const ShapeTreeIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeIterator& other) const {
    return node_ != other.node_;
  }

  // Dereferencing does not modify the iterator itself, so these are const and
  // can be used through a const iterator object.
  ValueType& operator*() const { return node_->data; }
  ValueType* operator->() const { return &node_->data; }

 private:
  // In leaves-only mode, advances node_ until it refers to a leaf or reaches
  // the end of the container. No-op otherwise.
  void SkipInteriorNodesForward() {
    while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) {
      ++node_;
    }
  }

  ContainerType* nodes_;
  IteratorType node_;
  // True if we should not include interior nodes in our walk. Deliberately not
  // const: a const data member would delete copy assignment, and iterators
  // must be assignable.
  bool iterate_leaves_only_;
};
394 
395 template <typename T>
CountSubshapes(const Shape & shape)396 int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
397   int64 current_count = 1;
398   if (shape.IsTuple()) {
399     int64 count = ShapeUtil::TupleElementCount(shape);
400     for (int i = 0; i < count; ++i) {
401       current_count += CountSubshapes(shape.tuple_shapes(i));
402     }
403   }
404   return current_count;
405 }
406 
407 template <typename T>
InitChildren(const Shape & shape,const T & init_value,Node * node,Index * index)408 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
409                                 Node* node, Index* index) {
410   if (shape.IsTuple()) {
411     const int64 size = ShapeUtil::TupleElementCount(shape);
412 #ifndef NDEBUG
413     index->children_count = size;
414 #endif
415     node->is_leaf = false;
416     ShapeIndex shape_index = node->data.first;
417     shape_index.push_back(0);
418 
419     // At the end of the index_table, reserve a continuous space to hold the
420     // children of current node. In order to enforce the invariant that all
421     // children of a given node are placed together, we need to do the
422     // reservation before we recurse into any of its children.
423     int64 children_start_position = index_table_.size();
424     index_table_.resize(index_table_.size() + size);
425 
426     for (int i = 0; i < size; ++i) {
427       shape_index[shape_index.size() - 1] = i;
428       index_table_[children_start_position + i].index = nodes_.size();
429       // The first child of the node in the index table is placed at the end of
430       // the table.
431       index_table_[children_start_position + i].children_start =
432           index_table_.size();
433       nodes_.emplace_back(shape_index, init_value);
434       InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
435                    &index_table_[children_start_position + i]);
436     }
437   } else {
438 #ifndef NDEBUG
439     index->children_count = 0;
440 #endif
441   }
442 }
443 
444 template <typename T>
InitChildren(const Shape & shape,Node * node,Index * index)445 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
446   if (shape.IsTuple()) {
447     const int64 size = ShapeUtil::TupleElementCount(shape);
448 #ifndef NDEBUG
449     index->children_count = size;
450 #endif
451     node->is_leaf = false;
452     ShapeIndex shape_index = node->data.first;
453     shape_index.push_back(0);
454 
455     // At the end of the index_table, reserve a continuous space to hold the
456     // children of current node. In order to enforce the invariant that all
457     // children of a given node are placed together, we need to do the
458     // reservation before we recurse into any of its children.
459     int64 children_start_position = index_table_.size();
460     index_table_.resize(index_table_.size() + size);
461 
462     for (int i = 0; i < size; ++i) {
463       shape_index[shape_index.size() - 1] = i;
464       index_table_[children_start_position + i].index = nodes_.size();
465       // The first child of the node in the index table is placed at the end of
466       // the table.
467       index_table_[children_start_position + i].children_start =
468           index_table_.size();
469       nodes_.emplace_back(shape_index);
470       InitChildren(shape.tuple_shapes(i), &nodes_.back(),
471                    &index_table_[children_start_position + i]);
472     }
473   } else {
474 #ifndef NDEBUG
475     index->children_count = 0;
476 #endif
477   }
478 }
479 
480 template <typename T>
ShapeTree(Shape shape)481 ShapeTree<T>::ShapeTree(Shape shape)
482     : shape_storage_(std::make_shared<Shape>(std::move(shape))),
483       shape_(shape_storage_.get()) {
484   const int64 count = CountSubshapes(*shape_);
485   nodes_.reserve(count);
486   nodes_.emplace_back(ShapeIndex{});
487 
488   index_table_.reserve(count);
489   index_table_.emplace_back(Index{0, 1});
490   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
491 }
492 
493 template <typename T>
ShapeTree(const Shape * shape)494 ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
495   const int64 count = CountSubshapes(*shape_);
496   nodes_.reserve(count);
497   nodes_.emplace_back(ShapeIndex{});
498 
499   index_table_.reserve(count);
500   index_table_.emplace_back(Index{0, 1});
501   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
502 }
503 
504 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape)505 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
506     : shape_storage_(shape), shape_(shape_storage_.get()) {
507   const int64 count = CountSubshapes(*shape_);
508   nodes_.reserve(count);
509   nodes_.emplace_back(ShapeIndex{});
510 
511   index_table_.reserve(count);
512   index_table_.emplace_back(Index{0, 1});
513   InitChildren(*shape_, &nodes_[0], &index_table_[0]);
514 }
515 
516 template <typename T>
ShapeTree(Shape shape,const T & init_value)517 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
518     : shape_storage_(std::make_shared<Shape>(std::move(shape))),
519       shape_(shape_storage_.get()) {
520   const int64 count = CountSubshapes(*shape_);
521   nodes_.reserve(count);
522   nodes_.emplace_back(ShapeIndex{}, init_value);
523 
524   index_table_.reserve(count);
525   index_table_.emplace_back(Index{0, 1});
526   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
527 }
528 
529 template <typename T>
ShapeTree(const Shape * shape,const T & init_value)530 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
531     : shape_(shape) {
532   const int64 count = CountSubshapes(*shape_);
533   nodes_.reserve(count);
534   nodes_.emplace_back(ShapeIndex{}, init_value);
535 
536   index_table_.reserve(count);
537   index_table_.emplace_back(Index{0, 1});
538   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
539 }
540 
541 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape,const T & init_value)542 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
543                         const T& init_value)
544     : shape_storage_(shape), shape_(shape_storage_.get()) {
545   const int64 count = CountSubshapes(*shape_);
546   nodes_.reserve(count);
547   nodes_.emplace_back(ShapeIndex{}, init_value);
548 
549   index_table_.reserve(count);
550   index_table_.emplace_back(Index{0, 1});
551   InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
552 }
553 
554 template <typename T>
element(ShapeIndexView index)555 const T& ShapeTree<T>::element(ShapeIndexView index) const {
556   return Lookup(index)->data.second;
557 }
558 
559 template <typename T>
mutable_element(ShapeIndexView index)560 T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
561   return &Lookup(index)->data.second;
562 }
563 
564 template <typename T>
Lookup(ShapeIndexView index)565 internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
566   Index* iter = &index_table_[0];
567   for (const int64 i : index) {
568     CHECK_GE(i, 0);
569 #ifndef NDEBUG
570     CHECK_LT(i, iter->children_count);
571 #endif
572     iter = &index_table_[iter->children_start + i];
573   }
574 
575   return &nodes_[iter->index];
576 }
577 
578 template <typename T>
Lookup(ShapeIndexView index)579 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
580     ShapeIndexView index) const {
581   return const_cast<ShapeTree*>(this)->Lookup(index);
582 }
583 
584 /* static */
585 template <typename T>
586 template <typename Fn>
ForEachHelper(const Fn & func,const std::vector<Node> & nodes)587 Status ShapeTree<T>::ForEachHelper(const Fn& func,
588                                    const std::vector<Node>& nodes) {
589   for (const auto& node : nodes) {
590     TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
591   }
592   return Status::OK();
593 }
594 
595 /* static */
596 template <typename T>
597 template <typename Fn>
ForEachMutableHelper(const Fn & func,std::vector<Node> * nodes)598 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
599                                           std::vector<Node>* nodes) {
600   for (auto& node : *nodes) {
601     TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
602   }
603   return Status::OK();
604 }
605 
606 template <typename T>
607 template <typename Fn>
ForEachElementWithStatus(const Fn & func)608 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
609   return ForEachHelper(func, nodes_);
610 }
611 
612 template <typename T>
613 template <typename Fn>
ForEachMutableElementWithStatus(const Fn & func)614 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
615   return ForEachMutableHelper(func, &nodes_);
616 }
617 
618 template <typename T>
619 template <typename Fn>
ForEachElement(const Fn & func)620 void ShapeTree<T>::ForEachElement(const Fn& func) const {
621   return ForEachHelper(
622              [&func](const ShapeIndex& index, const T& data) {
623                func(index, data);
624                return Status::OK();
625              },
626              nodes_)
627       .IgnoreError();
628 }
629 
630 template <typename T>
631 template <typename Fn>
ForEachMutableElement(const Fn & func)632 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
633   return ForEachMutableHelper(
634              [&func](const ShapeIndex& index, T* data) {
635                func(index, data);
636                return Status::OK();
637              },
638              &nodes_)
639       .IgnoreError();
640 }
641 
642 template <typename T>
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & source_base_index,const ShapeIndex & target_base_index)643 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
644                                    const ShapeIndex& source_base_index,
645                                    const ShapeIndex& target_base_index) {
646   CHECK(ShapeUtil::Compatible(
647       ShapeUtil::GetSubshape(shape(), target_base_index),
648       ShapeUtil::GetSubshape(other.shape(), source_base_index)));
649   ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
650                             const ShapeIndex& index, T* data) {
651     // Copy the data element only if index is in the
652     // subtree rooted at target_base_index.
653     for (int i = 0; i < target_base_index.size(); ++i) {
654       if (i >= index.size() || index[i] != target_base_index[i]) {
655         return;
656       }
657     }
658     // Construct source element index to copy from.
659     ShapeIndex source_index = source_base_index;
660     for (int i = target_base_index.size(); i < index.size(); ++i) {
661       source_index.push_back(index[i]);
662     }
663     *data = other.element(source_index);
664   });
665 }
666 
667 template <typename T>
668 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
669   bool equal = true;
670   ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
671     if (data != other.element(index)) {
672       equal = false;
673     }
674   });
675   return equal;
676 }
677 
678 }  // namespace xla
679 
680 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
681