• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
22 #include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/util/sparse/sparse_tensor.h"
25 
26 namespace tensorflow {
27 namespace boosted_trees {
28 namespace utils {
29 
30 // Enables row-wise iteration through examples from feature columns.
31 class ExamplesIterable {
32  public:
33   // Constructs an iterable given the desired examples slice and corresponding
34   // feature columns.
35   ExamplesIterable(
36       const std::vector<Tensor>& dense_float_feature_columns,
37       const std::vector<sparse::SparseTensor>& sparse_float_feature_columns,
38       const std::vector<sparse::SparseTensor>& sparse_int_feature_columns,
39       int64 example_start, int64 example_end);
40 
41   // Helper class to iterate through examples.
42   class Iterator {
43    public:
44     Iterator(ExamplesIterable* iter, int64 example_idx);
45 
46     Iterator& operator++() {
47       // Advance to next example.
48       ++example_idx_;
49 
50       // Update sparse column iterables.
51       for (auto& it : sparse_float_column_iterators_) {
52         ++it;
53       }
54       for (auto& it : sparse_int_column_iterators_) {
55         ++it;
56       }
57       return (*this);
58     }
59 
60     Iterator operator++(int) {
61       Iterator tmp(*this);
62       ++(*this);
63       return tmp;
64     }
65 
66     bool operator!=(const Iterator& other) const {
67       QCHECK_EQ(iter_, other.iter_);
68       return (example_idx_ != other.example_idx_);
69     }
70 
71     bool operator==(const Iterator& other) const {
72       QCHECK_EQ(iter_, other.iter_);
73       return (example_idx_ == other.example_idx_);
74     }
75 
76     const Example& operator*() {
77       // Set example index based on iterator.
78       example_.example_idx = example_idx_;
79 
80       // Get dense float values per column.
81       auto& dense_float_features = example_.dense_float_features;
82       for (size_t dense_float_idx = 0;
83            dense_float_idx < dense_float_features.size(); ++dense_float_idx) {
84         dense_float_features[dense_float_idx] =
85             iter_->dense_float_column_values_[dense_float_idx](example_idx_, 0);
86       }
87 
88       // Get sparse float values per column.
89       auto& sparse_float_features = example_.sparse_float_features;
90       // Iterate through each sparse float feature column.
91       for (size_t sparse_float_idx = 0;
92            sparse_float_idx < iter_->sparse_float_column_iterables_.size();
93            ++sparse_float_idx) {
94         // Clear info from a previous instance.
95         sparse_float_features[sparse_float_idx].Clear();
96 
97         // Get range for values tensor.
98         const auto& row_range =
99             (*sparse_float_column_iterators_[sparse_float_idx]);
100         DCHECK_EQ(example_idx_, row_range.example_idx);
101 
102         // If the example has this feature column.
103         if (row_range.start < row_range.end) {
104           const int32 dimension =
105               iter_->sparse_float_dimensions_[sparse_float_idx];
106           sparse_float_features[sparse_float_idx].SetDimension(dimension);
107           if (dimension <= 1) {
108             // single dimensional sparse feature column.
109             DCHECK_EQ(1, row_range.end - row_range.start);
110             sparse_float_features[sparse_float_idx].Add(
111                 0, iter_->sparse_float_column_values_[sparse_float_idx](
112                        row_range.start));
113           } else {
114             // Retrieve original indices tensor.
115             const TTypes<int64>::ConstMatrix& indices =
116                 iter_->sparse_float_column_iterables_[sparse_float_idx]
117                     .sparse_indices();
118 
119             sparse_float_features[sparse_float_idx].Reserve(row_range.end -
120                                                             row_range.start);
121 
122             // For each value.
123             for (int64 row_idx = row_range.start; row_idx < row_range.end;
124                  ++row_idx) {
125               // Get the feature id for the feature column and the value.
126               const int32 feature_id = indices(row_idx, 1);
127               DCHECK_EQ(example_idx_, indices(row_idx, 0));
128 
129               // Save the value to our sparse matrix.
130               sparse_float_features[sparse_float_idx].Add(
131                   feature_id,
132                   iter_->sparse_float_column_values_[sparse_float_idx](
133                       row_idx));
134             }
135           }
136         }
137       }
138 
139       // Get sparse int values per column.
140       auto& sparse_int_features = example_.sparse_int_features;
141       for (size_t sparse_int_idx = 0;
142            sparse_int_idx < sparse_int_features.size(); ++sparse_int_idx) {
143         const auto& row_range = (*sparse_int_column_iterators_[sparse_int_idx]);
144         DCHECK_EQ(example_idx_, row_range.example_idx);
145         sparse_int_features[sparse_int_idx].clear();
146         if (row_range.start < row_range.end) {
147           sparse_int_features[sparse_int_idx].reserve(row_range.end -
148                                                       row_range.start);
149           for (int64 row_idx = row_range.start; row_idx < row_range.end;
150                ++row_idx) {
151             sparse_int_features[sparse_int_idx].push_back(
152                 iter_->sparse_int_column_values_[sparse_int_idx](row_idx));
153           }
154         }
155       }
156 
157       return example_;
158     }
159 
160    private:
161     // Examples iterable (not owned).
162     const ExamplesIterable* iter_;
163 
164     // Example index.
165     int64 example_idx_;
166 
167     // Sparse float column iterators.
168     std::vector<SparseColumnIterable::Iterator> sparse_float_column_iterators_;
169 
170     // Sparse int column iterators.
171     std::vector<SparseColumnIterable::Iterator> sparse_int_column_iterators_;
172 
173     // Example placeholder.
174     Example example_;
175   };
176 
begin()177   Iterator begin() { return Iterator(this, example_start_); }
end()178   Iterator end() { return Iterator(this, example_end_); }
179 
180  private:
181   // Example slice spec.
182   const int64 example_start_;
183   const int64 example_end_;
184 
185   // Dense float column values.
186   std::vector<TTypes<float>::ConstMatrix> dense_float_column_values_;
187 
188   // Sparse float column iterables.
189   std::vector<SparseColumnIterable> sparse_float_column_iterables_;
190 
191   // Sparse float column values.
192   std::vector<TTypes<float>::ConstVec> sparse_float_column_values_;
193 
194   // Dimensions for sparse float feature columns.
195   std::vector<int32> sparse_float_dimensions_;
196 
197   // Sparse int column iterables.
198   std::vector<SparseColumnIterable> sparse_int_column_iterables_;
199 
200   // Sparse int column values.
201   std::vector<TTypes<int64>::ConstVec> sparse_int_column_values_;
202 };
203 
204 }  // namespace utils
205 }  // namespace boosted_trees
206 }  // namespace tensorflow
207 
208 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
209