1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
17 #define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
18
19 #include <vector>
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/types.h"
25
26 namespace tensorflow {
27 namespace sparse {
28
29 class GroupIterable; // Predeclare GroupIterable for Group.
30
31 // This class is returned when dereferencing a GroupIterable iterator.
32 // It provides the methods group(), indices(), and values(), which
33 // provide access into the underlying SparseTensor.
34 class Group {
35 public:
Group(GroupIterable * iter,int64 loc,int64 next_loc)36 Group(GroupIterable* iter, int64 loc, int64 next_loc)
37 : iter_(iter), loc_(loc), next_loc_(next_loc) {}
38
39 std::vector<int64> group() const;
40 TTypes<int64>::UnalignedConstMatrix indices() const;
41 template <typename T>
42 typename TTypes<T>::UnalignedVec values() const;
43
44 private:
45 GroupIterable* iter_;
46 int64 loc_;
47 int64 next_loc_;
48 };
49
50 /////////////////
51 // GroupIterable
52 /////////////////
53 //
54 // Returned when calling sparse_tensor.group({dim0, dim1, ...}).
55 //
56 // Please note: the sparse_tensor should already be ordered according
57 // to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
58 //
59 // Allows grouping and iteration of the SparseTensor according to the
60 // subset of dimensions provided to the group call.
61 //
62 // The actual grouping dimensions are stored in the
63 // internal vector group_dims_. Iterators inside the iterable provide
64 // the three methods:
65 //
66 // * group(): returns a vector with the current group dimension values.
67 // * indices(): a map of index, providing the indices in
68 // this group.
69 // * values(): a map of values, providing the values in
70 // this group.
71 //
72 // To iterate across GroupIterable, see examples in README.md.
73 //
74
// Definition of GroupIterable (forward-declared above for use by Group).
76 class GroupIterable {
77 public:
78 typedef gtl::ArraySlice<int64> VarDimArray;
79
GroupIterable(Tensor ix,Tensor vals,int dims,const VarDimArray & group_dims)80 GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
81 : ix_(ix),
82 ix_matrix_(ix_.matrix<int64>()),
83 vals_(vals),
84 dims_(dims),
85 group_dims_(group_dims.begin(), group_dims.end()) {}
86
87 class IteratorStep;
88
begin()89 IteratorStep begin() { return IteratorStep(this, 0); }
at(int64 loc)90 IteratorStep at(int64 loc) {
91 CHECK(loc >= 0 && loc <= ix_.dim_size(0))
92 << "loc provided must lie between 0 and " << ix_.dim_size(0);
93 return IteratorStep(this, loc);
94 }
end()95 IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
96
97 template <typename TIX>
GroupMatches(const TIX & ix,int64 loc_a,int64 loc_b)98 inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
99 bool matches = true;
100 for (int d : group_dims_) {
101 if (ix(loc_a, d) != ix(loc_b, d)) {
102 matches = false;
103 }
104 }
105 return matches;
106 }
107
108 class IteratorStep {
109 public:
IteratorStep(GroupIterable * iter,int64 loc)110 IteratorStep(GroupIterable* iter, int64 loc)
111 : iter_(iter), loc_(loc), next_loc_(loc_) {
112 UpdateEndOfGroup();
113 }
114
115 void UpdateEndOfGroup();
116 bool operator!=(const IteratorStep& rhs) const;
117 bool operator==(const IteratorStep& rhs) const;
118 IteratorStep& operator++(); // prefix ++
119 IteratorStep operator++(int); // postfix ++
120 Group operator*() const { return Group(iter_, loc_, next_loc_); }
loc()121 int64 loc() const { return loc_; }
122
123 private:
124 GroupIterable* iter_;
125 int64 loc_;
126 int64 next_loc_;
127 };
128
129 private:
130 friend class Group;
131 const Tensor ix_;
132 const TTypes<int64>::ConstMatrix ix_matrix_;
133 Tensor vals_;
134 const int dims_;
135 const gtl::InlinedVector<int64, 8> group_dims_;
136 };
137
138 // Implementation of Group::values<T>()
139 template <typename T>
values()140 typename TTypes<T>::UnalignedVec Group::values() const {
141 return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
142 next_loc_ - loc_);
143 }
144
145 } // namespace sparse
146 } // namespace tensorflow
147
148 #endif // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
149