• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
18 
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_types.h"
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 namespace boosted_trees {
26 namespace utils {
27 
28 // Enables row-wise iteration through examples on sparse feature columns.
29 class SparseColumnIterable {
30  public:
31   // Indicates a contiguous range for an example: [start, end).
32   struct ExampleRowRange {
33     int64 example_idx;
34     int64 start;
35     int64 end;
36   };
37 
38   // Helper class to iterate through examples and return the corresponding
39   // indices row range. Note that the row range can be empty in case a given
40   // example has no corresponding indices.
41   // An Iterator can be initialized from any example start offset, the
42   // corresponding range indicators will be initialized in log time.
43   class Iterator {
44    public:
45     Iterator(SparseColumnIterable* iter, int64 example_idx);
46 
47     Iterator& operator++() {
48       ++example_idx_;
49       if (cur_ < end_ && iter_->ix()(cur_, 0) < example_idx_) {
50         cur_ = next_;
51         UpdateNext();
52       }
53       return (*this);
54     }
55 
56     Iterator operator++(int) {
57       Iterator tmp(*this);
58       ++(*this);
59       return tmp;
60     }
61 
62     bool operator!=(const Iterator& other) const {
63       QCHECK_EQ(iter_, other.iter_);
64       return (example_idx_ != other.example_idx_);
65     }
66 
67     bool operator==(const Iterator& other) const {
68       QCHECK_EQ(iter_, other.iter_);
69       return (example_idx_ == other.example_idx_);
70     }
71 
72     const ExampleRowRange& operator*() {
73       range_.example_idx = example_idx_;
74       if (cur_ < end_ && iter_->ix()(cur_, 0) == example_idx_) {
75         range_.start = cur_;
76         range_.end = next_;
77       } else {
78         range_.start = 0;
79         range_.end = 0;
80       }
81       return range_;
82     }
83 
84    private:
UpdateNext()85     void UpdateNext() {
86       next_ = std::min(next_ + 1, end_);
87       while (next_ < end_ && iter_->ix()(cur_, 0) == iter_->ix()(next_, 0)) {
88         ++next_;
89       }
90     }
91 
92     const SparseColumnIterable* iter_;
93     int64 example_idx_;
94     int64 cur_;
95     int64 next_;
96     const int64 end_;
97     ExampleRowRange range_;
98   };
99 
100   // Constructs an iterable given the desired examples slice and corresponding
101   // feature columns.
SparseColumnIterable(TTypes<int64>::ConstMatrix ix,int64 example_start,int64 example_end)102   SparseColumnIterable(TTypes<int64>::ConstMatrix ix, int64 example_start,
103                        int64 example_end)
104       : ix_(ix), example_start_(example_start), example_end_(example_end) {
105     QCHECK(example_start >= 0 && example_end >= 0);
106   }
107 
begin()108   Iterator begin() { return Iterator(this, example_start_); }
end()109   Iterator end() { return Iterator(this, example_end_); }
110 
ix()111   const TTypes<int64>::ConstMatrix& ix() const { return ix_; }
example_start()112   int64 example_start() const { return example_start_; }
example_end()113   int64 example_end() const { return example_end_; }
114 
sparse_indices()115   const TTypes<int64>::ConstMatrix& sparse_indices() const { return ix_; }
116 
117  private:
118   // Sparse indices matrix.
119   TTypes<int64>::ConstMatrix ix_;
120 
121   // Example slice spec.
122   const int64 example_start_;
123   const int64 example_end_;
124 };
125 
126 }  // namespace utils
127 }  // namespace boosted_trees
128 }  // namespace tensorflow
129 
130 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
131