• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 using tensorflow::boosted_trees::learner::LearnerConfig;
21 
22 namespace tensorflow {
23 
24 using shape_inference::InferenceContext;
25 
ApplyGradientTreesPredictionShapeFn(InferenceContext * c)26 static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
27   string learner_config_str;
28   // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
29   c->GetAttr("learner_config", &learner_config_str).IgnoreError();
30   LearnerConfig learner_config;
31   ParseProtoUnlimited(&learner_config, learner_config_str);
32 
33   bool reduce_dim;
34   c->GetAttr("reduce_dim", &reduce_dim).IgnoreError();
35   // Sets the shape of the output as a matrix.
36   c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
37                               reduce_dim ? learner_config.num_classes() - 1
38                                          : learner_config.num_classes())});
39   c->set_output(1, {c->UnknownShape()});
40   return Status::OK();
41 }
42 
ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext * c)43 static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) {
44   string learner_config_str;
45   c->GetAttr("learner_config", &learner_config_str).IgnoreError();
46   LearnerConfig learner_config;
47   ParseProtoUnlimited(&learner_config, learner_config_str);
48 
49   bool reduce_dim;
50   c->GetAttr("reduce_dim", &reduce_dim).IgnoreError();
51   // Sets the shape of the output as a matrix.
52   c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
53                               reduce_dim ? learner_config.num_classes() - 1
54                                          : learner_config.num_classes())});
55   c->set_output(1, {c->UnknownShape()});
56   c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim,
57                               InferenceContext::kUnknownDim)});
58   return Status::OK();
59 }
60 
61 REGISTER_OP("GradientTreesPrediction")
62     .Attr("learner_config: string")
63     .Attr("num_dense_float_features: int >= 0")
64     .Attr("num_sparse_float_features: int >= 0")
65     .Attr("num_sparse_int_features: int >= 0")
66     .Attr("use_locking: bool = false")
67     .Attr("apply_dropout: bool")
68     .Attr("apply_averaging: bool")
69     .Attr("center_bias: bool")
70     .Attr("reduce_dim: bool")
71     .Input("tree_ensemble_handle: resource")
72     .Input("seed: int64")
73     .Input("dense_float_features: num_dense_float_features * float")
74     .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
75     .Input("sparse_float_feature_values: num_sparse_float_features * float")
76     .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
77     .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
78     .Input("sparse_int_feature_values: num_sparse_int_features * int64")
79     .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
80     .Output("predictions: float")
81     .Output("drop_out_tree_indices_weights: float")
82     .SetShapeFn(ApplyGradientTreesPredictionShapeFn)
83     .Doc(R"doc(
84 Runs multiple additive regression forests predictors on input instances
85 and computes the final prediction for each class.
86 
87 learner_config: Config for the learner of type LearnerConfig proto. Prediction
88 ops for now uses only LearningRateDropoutDrivenConfig config from the learner.
89 num_dense_float_features: Number of dense float features.
90 num_sparse_float_features: Number of sparse float features.
91 num_sparse_int_features: Number of sparse int features.
92 use_locking: Whether to use locking.
93 seed: random seed to be used for dropout.
94 reduce_dim: whether to reduce the dimension (legacy impl) or not.
95 apply_dropout: whether to apply dropout during prediction.
96 apply_averaging: whether averaging of tree ensembles should take place. If set
97 to true, will be based on AveragingConfig from learner_config.
98 tree_ensemble_handle: The handle to the tree ensemble.
99 dense_float_features: Rank 2 Tensors containing dense float feature values.
100 sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
101 sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
102 sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
103 sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
104 sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
105 sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
106 predictions: Rank 2 Tensor containing predictions per example per class.
107 drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
108 and original weights of those trees during prediction.
109 )doc");
110 
111 REGISTER_OP("GradientTreesPredictionVerbose")
112     .Attr("learner_config: string")
113     .Attr("num_dense_float_features: int >= 0")
114     .Attr("num_sparse_float_features: int >= 0")
115     .Attr("num_sparse_int_features: int >= 0")
116     .Attr("use_locking: bool = false")
117     .Attr("apply_dropout: bool")
118     .Attr("apply_averaging: bool")
119     .Attr("center_bias: bool")
120     .Attr("reduce_dim: bool")
121     .Input("tree_ensemble_handle: resource")
122     .Input("seed: int64")
123     .Input("dense_float_features: num_dense_float_features * float")
124     .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
125     .Input("sparse_float_feature_values: num_sparse_float_features * float")
126     .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
127     .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
128     .Input("sparse_int_feature_values: num_sparse_int_features * int64")
129     .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
130     .Output("predictions: float")
131     .Output("drop_out_tree_indices_weights: float")
132     .Output("leaf_index: int32")
133     .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn)
134     .Doc(R"doc(
135 Runs multiple additive regression forests predictors on input instances
136 and computes the final prediction for each class, and outputs a matrix of
137 leaf ids per each tree in an ensemble.
138 
139 learner_config: Config for the learner of type LearnerConfig proto. Prediction
140 ops for now uses only LearningRateDropoutDrivenConfig config from the learner.
141 num_dense_float_features: Number of dense float features.
142 num_sparse_float_features: Number of sparse float features.
143 num_sparse_int_features: Number of sparse int features.
144 use_locking: Whether to use locking.
145 seed: random seed to be used for dropout.
146 reduce_dim: whether to reduce the dimension (legacy impl) or not.
147 apply_dropout: whether to apply dropout during prediction.
148 apply_averaging: whether averaging of tree ensembles should take place. If set
149 to true, will be based on AveragingConfig from learner_config.
150 tree_ensemble_handle: The handle to the tree ensemble.
151 dense_float_features: Rank 2 Tensors containing dense float feature values.
152 sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
153 sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
154 sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
155 sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
156 sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
157 sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
158 predictions: Rank 2 Tensor containing predictions per example per class.
159 drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
160 leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up.
161 )doc");
162 
163 REGISTER_OP("GradientTreesPartitionExamples")
164     .Attr("num_dense_float_features: int >= 0")
165     .Attr("num_sparse_float_features: int >= 0")
166     .Attr("num_sparse_int_features: int >= 0")
167     .Attr("use_locking: bool = false")
168     .Input("tree_ensemble_handle: resource")
169     .Input("dense_float_features: num_dense_float_features * float")
170     .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
171     .Input("sparse_float_feature_values: num_sparse_float_features * float")
172     .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
173     .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
174     .Input("sparse_int_feature_values: num_sparse_int_features * int64")
175     .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
176     .Output("partition_ids: int32")
__anonbe66d2360102(InferenceContext* c) 177     .SetShapeFn([](InferenceContext* c) {
178       return c->set_output("partition_ids",
179                            {c->Vector(InferenceContext::kUnknownDim)});
180     })
181     .Doc(R"doc(
182 Splits input examples into the leaves of the tree.
183 
184 num_dense_float_features: Number of dense float features.
185 num_sparse_float_features: Number of sparse float features.
186 num_sparse_int_features: Number of sparse int features.
187 use_locking: Whether to use locking.
188 tree_ensemble_handle: The handle to the tree ensemble.
189 dense_float_features: Rank 2 Tensors containing dense float feature values.
190 sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
191 sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
192 sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
193 sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
194 sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
195 sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
196 partition_ids: Rank 1 Tensor containing partition ids per example.
197 )doc");
198 
199 }  // namespace tensorflow
200