• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
17 #include "tensorflow/contrib/boosted_trees/lib/utils/macros.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/tensor.h"
20 
21 namespace tensorflow {
22 namespace boosted_trees {
23 namespace utils {
24 
OpInputListToTensorVec(const OpInputList & input_list)25 std::vector<Tensor> TensorUtils::OpInputListToTensorVec(
26     const OpInputList& input_list) {
27   std::vector<Tensor> tensor_vec;
28   tensor_vec.reserve(input_list.size());
29   for (const Tensor& tensor : input_list) {
30     tensor_vec.emplace_back(tensor);
31   }
32   return tensor_vec;
33 }
34 
ReadDenseFloatFeatures(OpKernelContext * const context,OpInputList * features_list)35 Status TensorUtils::ReadDenseFloatFeatures(OpKernelContext* const context,
36                                            OpInputList* features_list) {
37   // Constants.
38   constexpr auto kDenseFloatFeaturesName = "dense_float_features";
39 
40   // Read dense float features list;
41   TF_RETURN_IF_ERROR(
42       context->input_list(kDenseFloatFeaturesName, features_list));
43   return Status::OK();
44 }
45 
ReadSparseFloatFeatures(OpKernelContext * const context,OpInputList * features_indices_list,OpInputList * feature_values_list,OpInputList * feature_shapes_list)46 Status TensorUtils::ReadSparseFloatFeatures(OpKernelContext* const context,
47                                             OpInputList* features_indices_list,
48                                             OpInputList* feature_values_list,
49                                             OpInputList* feature_shapes_list) {
50   // Constants.
51   constexpr auto kSparseFloatFeatureIndicesName =
52       "sparse_float_feature_indices";
53   constexpr auto kSparseFloatFeatureValuesName = "sparse_float_feature_values";
54   constexpr auto kSparseFloatFeatureShapesName = "sparse_float_feature_shapes";
55 
56   // Read sparse float features list;
57   TF_RETURN_IF_ERROR(context->input_list(kSparseFloatFeatureIndicesName,
58                                          features_indices_list));
59   TF_RETURN_IF_ERROR(
60       context->input_list(kSparseFloatFeatureValuesName, feature_values_list));
61   TF_RETURN_IF_ERROR(
62       context->input_list(kSparseFloatFeatureShapesName, feature_shapes_list));
63   return Status::OK();
64 }
65 
ReadSparseIntFeatures(OpKernelContext * const context,OpInputList * features_indices_list,OpInputList * feature_values_list,OpInputList * feature_shapes_list)66 Status TensorUtils::ReadSparseIntFeatures(OpKernelContext* const context,
67                                           OpInputList* features_indices_list,
68                                           OpInputList* feature_values_list,
69                                           OpInputList* feature_shapes_list) {
70   // Constants.
71   constexpr auto kSparseIntFeatureIndicesName = "sparse_int_feature_indices";
72   constexpr auto kSparseIntFeatureValuesName = "sparse_int_feature_values";
73   constexpr auto kSparseIntFeatureShapesName = "sparse_int_feature_shapes";
74 
75   // Read sparse int features list;
76   TF_RETURN_IF_ERROR(
77       context->input_list(kSparseIntFeatureIndicesName, features_indices_list));
78   TF_RETURN_IF_ERROR(
79       context->input_list(kSparseIntFeatureValuesName, feature_values_list));
80   TF_RETURN_IF_ERROR(
81       context->input_list(kSparseIntFeatureShapesName, feature_shapes_list));
82   return Status::OK();
83 }
84 
InferBatchSize(const OpInputList & dense_float_features_list,const OpInputList & sparse_float_feature_shapes_list,const OpInputList & sparse_int_feature_shapes_list)85 int64 TensorUtils::InferBatchSize(
86     const OpInputList& dense_float_features_list,
87     const OpInputList& sparse_float_feature_shapes_list,
88     const OpInputList& sparse_int_feature_shapes_list) {
89   if (dense_float_features_list.size() > 0) {
90     return dense_float_features_list[0].dim_size(0);
91   }
92   if (sparse_float_feature_shapes_list.size() > 0) {
93     return sparse_float_feature_shapes_list[0].flat<int64>()(0);
94   }
95   if (sparse_int_feature_shapes_list.size() > 0) {
96     return sparse_int_feature_shapes_list[0].flat<int64>()(0);
97   }
98   LOG(QFATAL) << "Could not infer batch size due to empty feature set.";
99 }
100 
101 }  // namespace utils
102 }  // namespace boosted_trees
103 }  // namespace tensorflow
104