1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
16
17 namespace tensorflow {
18 namespace boosted_trees {
19 namespace utils {
20
21 using Iterator = ExamplesIterable::Iterator;
22
ExamplesIterable(const std::vector<Tensor> & dense_float_feature_columns,const std::vector<sparse::SparseTensor> & sparse_float_feature_columns,const std::vector<sparse::SparseTensor> & sparse_int_feature_columns,int64 example_start,int64 example_end)23 ExamplesIterable::ExamplesIterable(
24 const std::vector<Tensor>& dense_float_feature_columns,
25 const std::vector<sparse::SparseTensor>& sparse_float_feature_columns,
26 const std::vector<sparse::SparseTensor>& sparse_int_feature_columns,
27 int64 example_start, int64 example_end)
28 : example_start_(example_start), example_end_(example_end) {
29 // Create dense float column values.
30 dense_float_column_values_.reserve(dense_float_feature_columns.size());
31 for (auto& dense_float_column : dense_float_feature_columns) {
32 dense_float_column_values_.emplace_back(
33 dense_float_column.template matrix<float>());
34 }
35
36 // Create sparse float column iterables and values.
37 sparse_float_column_iterables_.reserve(sparse_float_feature_columns.size());
38 sparse_float_column_values_.reserve(sparse_float_feature_columns.size());
39 sparse_float_dimensions_.reserve(sparse_float_feature_columns.size());
40 for (auto& sparse_float_column : sparse_float_feature_columns) {
41 sparse_float_column_iterables_.emplace_back(
42 sparse_float_column.indices().template matrix<int64>(), example_start,
43 example_end);
44 sparse_float_column_values_.emplace_back(
45 sparse_float_column.values().template vec<float>());
46 sparse_float_dimensions_.push_back(sparse_float_column.shape()[1]);
47 }
48
49 // Create sparse int column iterables and values.
50 sparse_int_column_iterables_.reserve(sparse_int_feature_columns.size());
51 sparse_int_column_values_.reserve(sparse_int_feature_columns.size());
52 for (auto& sparse_int_column : sparse_int_feature_columns) {
53 sparse_int_column_iterables_.emplace_back(
54 sparse_int_column.indices().template matrix<int64>(), example_start,
55 example_end);
56 sparse_int_column_values_.emplace_back(
57 sparse_int_column.values().template vec<int64>());
58 }
59 }
60
Iterator(ExamplesIterable * iter,int64 example_idx)61 Iterator::Iterator(ExamplesIterable* iter, int64 example_idx)
62 : iter_(iter), example_idx_(example_idx) {
63 // Create sparse iterators.
64 sparse_float_column_iterators_.reserve(
65 iter->sparse_float_column_iterables_.size());
66 for (auto& iterable : iter->sparse_float_column_iterables_) {
67 sparse_float_column_iterators_.emplace_back(iterable.begin());
68 }
69 sparse_int_column_iterators_.reserve(
70 iter->sparse_int_column_iterables_.size());
71 for (auto& iterable : iter->sparse_int_column_iterables_) {
72 sparse_int_column_iterators_.emplace_back(iterable.begin());
73 }
74
75 // Pre-size example features.
76 example_.dense_float_features.resize(
77 iter_->dense_float_column_values_.size());
78 example_.sparse_int_features.resize(iter_->sparse_int_column_values_.size());
79 example_.sparse_float_features.resize(
80 iter_->sparse_float_column_values_.size());
81 }
82
83 } // namespace utils
84 } // namespace boosted_trees
85 } // namespace tensorflow
86