// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
17 
18 #include <vector>
19 #include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/platform/macros.h"
24 #include "tensorflow/core/util/sparse/sparse_tensor.h"
25 
26 namespace tensorflow {
27 namespace boosted_trees {
28 namespace utils {
29 
30 class BatchFeatures {
31  public:
32   // Constructs batch features with a fixed batch size.
BatchFeatures(int64 batch_size)33   explicit BatchFeatures(int64 batch_size) : batch_size_(batch_size) {}
34 
35   // Disallow copy and assign.
36   BatchFeatures(const BatchFeatures& other) = delete;
37   BatchFeatures& operator=(const BatchFeatures& other) = delete;
38 
39   // Method to initialize batch features from op kernel context.
40   Status Initialize(std::vector<Tensor> dense_float_features_list,
41                     std::vector<Tensor> sparse_float_feature_indices_list,
42                     std::vector<Tensor> sparse_float_feature_values_list,
43                     std::vector<Tensor> sparse_float_feature_shapes_list,
44                     std::vector<Tensor> sparse_int_feature_indices_list,
45                     std::vector<Tensor> sparse_int_feature_values_list,
46                     std::vector<Tensor> sparse_int_feature_shapes_list);
47 
GetFeatureColumnSizes(int64 * const num_dense_float_features,int64 * const num_sparse_float_features,int64 * const num_sparse_int_features)48   Status GetFeatureColumnSizes(int64* const num_dense_float_features,
49                                int64* const num_sparse_float_features,
50                                int64* const num_sparse_int_features) const {
51     QCHECK_NE(num_dense_float_features, nullptr);
52     QCHECK_NE(num_sparse_float_features, nullptr);
53     QCHECK_NE(num_sparse_int_features, nullptr);
54     *num_dense_float_features = dense_float_feature_columns_.size();
55     *num_sparse_float_features = sparse_float_feature_columns_.size();
56     *num_sparse_int_features = sparse_int_feature_columns_.size();
57     if (*num_dense_float_features == 0 && *num_sparse_float_features == 0 &&
58         *num_sparse_int_features == 0) {
59       return errors::FailedPrecondition("Not initialized yet.");
60     }
61     return Status::OK();
62   }
63 
64   // Creates an example iterable for the requested slice.
examples_iterable(int64 example_start,int64 example_end)65   ExamplesIterable examples_iterable(int64 example_start,
66                                      int64 example_end) const {
67     QCHECK(example_start >= 0 && example_end >= 0);
68     QCHECK(example_start < batch_size_ && example_end <= batch_size_);
69     return ExamplesIterable(
70         dense_float_feature_columns_, sparse_float_feature_columns_,
71         sparse_int_feature_columns_, example_start, example_end);
72   }
73 
74   // Returns the fixed batch size.
batch_size()75   int64 batch_size() const { return batch_size_; }
76 
77  private:
78   // Total number of examples in the batch.
79   const int64 batch_size_;
80 
81   // Dense float feature columns.
82   std::vector<Tensor> dense_float_feature_columns_;
83 
84   // Sparse float feature columns.
85   std::vector<sparse::SparseTensor> sparse_float_feature_columns_;
86 
87   // Sparse int feature columns.
88   std::vector<sparse::SparseTensor> sparse_int_feature_columns_;
89 };
90 
91 }  // namespace utils
92 }  // namespace boosted_trees
93 }  // namespace tensorflow
94 
95 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
96