1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15
16 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
17 #include "tensorflow/contrib/boosted_trees/lib/utils/macros.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/tensor.h"
20
21 namespace tensorflow {
22 namespace boosted_trees {
23 namespace utils {
24
OpInputListToTensorVec(const OpInputList & input_list)25 std::vector<Tensor> TensorUtils::OpInputListToTensorVec(
26 const OpInputList& input_list) {
27 std::vector<Tensor> tensor_vec;
28 tensor_vec.reserve(input_list.size());
29 for (const Tensor& tensor : input_list) {
30 tensor_vec.emplace_back(tensor);
31 }
32 return tensor_vec;
33 }
34
ReadDenseFloatFeatures(OpKernelContext * const context,OpInputList * features_list)35 Status TensorUtils::ReadDenseFloatFeatures(OpKernelContext* const context,
36 OpInputList* features_list) {
37 // Constants.
38 constexpr auto kDenseFloatFeaturesName = "dense_float_features";
39
40 // Read dense float features list;
41 TF_RETURN_IF_ERROR(
42 context->input_list(kDenseFloatFeaturesName, features_list));
43 return Status::OK();
44 }
45
ReadSparseFloatFeatures(OpKernelContext * const context,OpInputList * features_indices_list,OpInputList * feature_values_list,OpInputList * feature_shapes_list)46 Status TensorUtils::ReadSparseFloatFeatures(OpKernelContext* const context,
47 OpInputList* features_indices_list,
48 OpInputList* feature_values_list,
49 OpInputList* feature_shapes_list) {
50 // Constants.
51 constexpr auto kSparseFloatFeatureIndicesName =
52 "sparse_float_feature_indices";
53 constexpr auto kSparseFloatFeatureValuesName = "sparse_float_feature_values";
54 constexpr auto kSparseFloatFeatureShapesName = "sparse_float_feature_shapes";
55
56 // Read sparse float features list;
57 TF_RETURN_IF_ERROR(context->input_list(kSparseFloatFeatureIndicesName,
58 features_indices_list));
59 TF_RETURN_IF_ERROR(
60 context->input_list(kSparseFloatFeatureValuesName, feature_values_list));
61 TF_RETURN_IF_ERROR(
62 context->input_list(kSparseFloatFeatureShapesName, feature_shapes_list));
63 return Status::OK();
64 }
65
ReadSparseIntFeatures(OpKernelContext * const context,OpInputList * features_indices_list,OpInputList * feature_values_list,OpInputList * feature_shapes_list)66 Status TensorUtils::ReadSparseIntFeatures(OpKernelContext* const context,
67 OpInputList* features_indices_list,
68 OpInputList* feature_values_list,
69 OpInputList* feature_shapes_list) {
70 // Constants.
71 constexpr auto kSparseIntFeatureIndicesName = "sparse_int_feature_indices";
72 constexpr auto kSparseIntFeatureValuesName = "sparse_int_feature_values";
73 constexpr auto kSparseIntFeatureShapesName = "sparse_int_feature_shapes";
74
75 // Read sparse int features list;
76 TF_RETURN_IF_ERROR(
77 context->input_list(kSparseIntFeatureIndicesName, features_indices_list));
78 TF_RETURN_IF_ERROR(
79 context->input_list(kSparseIntFeatureValuesName, feature_values_list));
80 TF_RETURN_IF_ERROR(
81 context->input_list(kSparseIntFeatureShapesName, feature_shapes_list));
82 return Status::OK();
83 }
84
InferBatchSize(const OpInputList & dense_float_features_list,const OpInputList & sparse_float_feature_shapes_list,const OpInputList & sparse_int_feature_shapes_list)85 int64 TensorUtils::InferBatchSize(
86 const OpInputList& dense_float_features_list,
87 const OpInputList& sparse_float_feature_shapes_list,
88 const OpInputList& sparse_int_feature_shapes_list) {
89 if (dense_float_features_list.size() > 0) {
90 return dense_float_features_list[0].dim_size(0);
91 }
92 if (sparse_float_feature_shapes_list.size() > 0) {
93 return sparse_float_feature_shapes_list[0].flat<int64>()(0);
94 }
95 if (sparse_int_feature_shapes_list.size() > 0) {
96 return sparse_int_feature_shapes_list[0].flat<int64>()(0);
97 }
98 LOG(QFATAL) << "Could not infer batch size due to empty feature set.";
99 }
100
101 } // namespace utils
102 } // namespace boosted_trees
103 } // namespace tensorflow
104