• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
17 #define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
18 
19 #include <vector>
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 namespace sparse {
28 
29 class GroupIterable;  // Predeclare GroupIterable for Group.
30 
31 // This class is returned when dereferencing a GroupIterable iterator.
32 // It provides the methods group(), indices(), and values(), which
33 // provide access into the underlying SparseTensor.
34 class Group {
35  public:
Group(GroupIterable * iter,int64 loc,int64 next_loc)36   Group(GroupIterable* iter, int64 loc, int64 next_loc)
37       : iter_(iter), loc_(loc), next_loc_(next_loc) {}
38 
39   std::vector<int64> group() const;
40   TTypes<int64>::UnalignedConstMatrix indices() const;
41   template <typename T>
42   typename TTypes<T>::UnalignedVec values() const;
43 
44  private:
45   GroupIterable* iter_;
46   int64 loc_;
47   int64 next_loc_;
48 };
49 
50 /////////////////
51 // GroupIterable
52 /////////////////
53 //
54 // Returned when calling sparse_tensor.group({dim0, dim1, ...}).
55 //
56 // Please note: the sparse_tensor should already be ordered according
57 // to {dim0, dim1, ...}.  Otherwise this iteration will return invalid groups.
58 //
59 // Allows grouping and iteration of the SparseTensor according to the
60 // subset of dimensions provided to the group call.
61 //
62 // The actual grouping dimensions are stored in the
63 // internal vector group_dims_.  Iterators inside the iterable provide
64 // the three methods:
65 //
66 // *  group(): returns a vector with the current group dimension values.
67 // *  indices(): a map of index, providing the indices in
68 //    this group.
69 // *  values(): a map of values, providing the values in
70 //    this group.
71 //
72 // To iterate across GroupIterable, see examples in README.md.
73 //
74 
75 // Forward declaration of SparseTensor
76 class GroupIterable {
77  public:
78   typedef gtl::ArraySlice<int64> VarDimArray;
79 
GroupIterable(Tensor ix,Tensor vals,int dims,const VarDimArray & group_dims)80   GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
81       : ix_(ix),
82         ix_matrix_(ix_.matrix<int64>()),
83         vals_(vals),
84         dims_(dims),
85         group_dims_(group_dims.begin(), group_dims.end()) {}
86 
87   class IteratorStep;
88 
begin()89   IteratorStep begin() { return IteratorStep(this, 0); }
at(int64 loc)90   IteratorStep at(int64 loc) {
91     CHECK(loc >= 0 && loc <= ix_.dim_size(0))
92         << "loc provided must lie between 0 and " << ix_.dim_size(0);
93     return IteratorStep(this, loc);
94   }
end()95   IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
96 
97   template <typename TIX>
GroupMatches(const TIX & ix,int64 loc_a,int64 loc_b)98   inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
99     bool matches = true;
100     for (int d : group_dims_) {
101       if (ix(loc_a, d) != ix(loc_b, d)) {
102         matches = false;
103       }
104     }
105     return matches;
106   }
107 
108   class IteratorStep {
109    public:
IteratorStep(GroupIterable * iter,int64 loc)110     IteratorStep(GroupIterable* iter, int64 loc)
111         : iter_(iter), loc_(loc), next_loc_(loc_) {
112       UpdateEndOfGroup();
113     }
114 
115     void UpdateEndOfGroup();
116     bool operator!=(const IteratorStep& rhs) const;
117     bool operator==(const IteratorStep& rhs) const;
118     IteratorStep& operator++();    // prefix ++
119     IteratorStep operator++(int);  // postfix ++
120     Group operator*() const { return Group(iter_, loc_, next_loc_); }
loc()121     int64 loc() const { return loc_; }
122 
123    private:
124     GroupIterable* iter_;
125     int64 loc_;
126     int64 next_loc_;
127   };
128 
129  private:
130   friend class Group;
131   const Tensor ix_;
132   const TTypes<int64>::ConstMatrix ix_matrix_;
133   Tensor vals_;
134   const int dims_;
135   const gtl::InlinedVector<int64, 8> group_dims_;
136 };
137 
138 // Implementation of Group::values<T>()
139 template <typename T>
values()140 typename TTypes<T>::UnalignedVec Group::values() const {
141   return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
142                                           next_loc_ - loc_);
143 }
144 
145 }  // namespace sparse
146 }  // namespace tensorflow
147 
148 #endif  // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
149