• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
15 #include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
16 
17 namespace tensorflow {
18 namespace boosted_trees {
19 namespace utils {
20 
21 using Iterator = ExamplesIterable::Iterator;
22 
ExamplesIterable(const std::vector<Tensor> & dense_float_feature_columns,const std::vector<sparse::SparseTensor> & sparse_float_feature_columns,const std::vector<sparse::SparseTensor> & sparse_int_feature_columns,int64 example_start,int64 example_end)23 ExamplesIterable::ExamplesIterable(
24     const std::vector<Tensor>& dense_float_feature_columns,
25     const std::vector<sparse::SparseTensor>& sparse_float_feature_columns,
26     const std::vector<sparse::SparseTensor>& sparse_int_feature_columns,
27     int64 example_start, int64 example_end)
28     : example_start_(example_start), example_end_(example_end) {
29   // Create dense float column values.
30   dense_float_column_values_.reserve(dense_float_feature_columns.size());
31   for (auto& dense_float_column : dense_float_feature_columns) {
32     dense_float_column_values_.emplace_back(
33         dense_float_column.template matrix<float>());
34   }
35 
36   // Create sparse float column iterables and values.
37   sparse_float_column_iterables_.reserve(sparse_float_feature_columns.size());
38   sparse_float_column_values_.reserve(sparse_float_feature_columns.size());
39   sparse_float_dimensions_.reserve(sparse_float_feature_columns.size());
40   for (auto& sparse_float_column : sparse_float_feature_columns) {
41     sparse_float_column_iterables_.emplace_back(
42         sparse_float_column.indices().template matrix<int64>(), example_start,
43         example_end);
44     sparse_float_column_values_.emplace_back(
45         sparse_float_column.values().template vec<float>());
46     sparse_float_dimensions_.push_back(sparse_float_column.shape()[1]);
47   }
48 
49   // Create sparse int column iterables and values.
50   sparse_int_column_iterables_.reserve(sparse_int_feature_columns.size());
51   sparse_int_column_values_.reserve(sparse_int_feature_columns.size());
52   for (auto& sparse_int_column : sparse_int_feature_columns) {
53     sparse_int_column_iterables_.emplace_back(
54         sparse_int_column.indices().template matrix<int64>(), example_start,
55         example_end);
56     sparse_int_column_values_.emplace_back(
57         sparse_int_column.values().template vec<int64>());
58   }
59 }
60 
Iterator(ExamplesIterable * iter,int64 example_idx)61 Iterator::Iterator(ExamplesIterable* iter, int64 example_idx)
62     : iter_(iter), example_idx_(example_idx) {
63   // Create sparse iterators.
64   sparse_float_column_iterators_.reserve(
65       iter->sparse_float_column_iterables_.size());
66   for (auto& iterable : iter->sparse_float_column_iterables_) {
67     sparse_float_column_iterators_.emplace_back(iterable.begin());
68   }
69   sparse_int_column_iterators_.reserve(
70       iter->sparse_int_column_iterables_.size());
71   for (auto& iterable : iter->sparse_int_column_iterables_) {
72     sparse_int_column_iterators_.emplace_back(iterable.begin());
73   }
74 
75   // Pre-size example features.
76   example_.dense_float_features.resize(
77       iter_->dense_float_column_values_.size());
78   example_.sparse_int_features.resize(iter_->sparse_int_column_values_.size());
79   example_.sparse_float_features.resize(
80       iter_->sparse_float_column_values_.size());
81 }
82 
83 }  // namespace utils
84 }  // namespace boosted_trees
85 }  // namespace tensorflow
86