• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 REGISTER_OP("BuildDenseInequalitySplits")
26     .Input("num_minibatches: int64")
27     .Input("partition_ids: int32")
28     .Input("bucket_ids: int64")
29     .Input("gradients: float32")
30     .Input("hessians: float32")
31     .Input("bucket_boundaries: float32")
32     .Input("class_id: int32")
33     .Input("feature_column_group_id: int32")
34     .Input("l1_regularization: float")
35     .Input("l2_regularization: float")
36     .Input("tree_complexity_regularization: float")
37     .Input("min_node_weight: float")
38     .Input("multiclass_strategy: int32")
39     .Input("weak_learner_type: int32")
40     .Output("output_partition_ids: int32")
41     .Output("gains: float32")
42     .Output("split_infos: string")
__anon295dc3ce0102(InferenceContext* c) 43     .SetShapeFn([](InferenceContext* c) {
44       DimensionHandle unused_dim;
45       ShapeHandle unused_shape;
46       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape));
47 
48       ShapeHandle partition_ids_shape;
49       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape));
50       ShapeHandle bucket_ids_shape;
51       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape));
52       ShapeHandle gradients_shape;
53       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape));
54       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
55                                   c->Dim(gradients_shape, 0), &unused_dim));
56       ShapeHandle hessians_shape;
57       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape));
58       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
59                                   c->Dim(hessians_shape, 0), &unused_dim));
60       ShapeHandle bucket_boundaries_shape;
61       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &bucket_boundaries_shape));
62       c->set_output(0, c->Vector(c->UnknownDim()));
63       c->set_output(1, c->Vector(c->UnknownDim()));
64       c->set_output(2, c->Vector(c->UnknownDim()));
65       return Status::OK();
66     })
67     .Doc(R"doc(
68 Find the split that has the best gain for the accumulated stats.
69 
70 num_minibatches: A scalar, the number of times per example gradients & hessians
71     were accumulated. The stats are divided by this to get per example stats.
72 partition_ids: A rank 1 tensor of partition IDs.
73 bucket_ids: A rank 2 tensor of buckets IDs and dimensions.
74 gradients: A rank 1 tensor of gradients.
75 hessians: A rank 1 tensor of hessians.
76 bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
77 class_id: A scalar, the class id for which we're building the splits.
78 feature_column_group_id: A scalar, the index of the feature we are spiltting on.
79 l1_regularization: A scalar, which specifies the l1 regularization term.
80 l2_regularization: A scalar, which specifies the l2 regularization term.
81 tree_complexity_regularization: A scalar, which specifies the tree complexity
82     regularization term.
83 min_node_weight: A scalar, minimum sum of example hessian needed in a child.
84     If a split results in a leaf node with a smaller value, the split will not
85     be considered.
86 multiclass_strategy: A scalar, specifying the multiclass handling strategy.
87     See LearnerConfig.MultiClassStrategy for valid values.
88 weak_learner_type: A scalar, specifying the weak learner type to use.
89     See LearnerConfig.WeakLearnerType for valid values.
90 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
91     for.
92 gains: A rank 1 tensor, for the computed gain for the created splits.
93 split_infos: A rank 1 tensor of serialized protos which contains the
94     `SplitInfo`s.
95 )doc");
96 
97 REGISTER_OP("BuildSparseInequalitySplits")
98     .Input("num_minibatches: int64")
99     .Input("partition_ids: int32")
100     .Input("bucket_ids: int64")
101     .Input("gradients: float32")
102     .Input("hessians: float32")
103     .Input("bucket_boundaries: float32")
104     .Input("class_id: int32")
105     .Input("feature_column_group_id: int32")
106     .Input("bias_feature_id: int64")
107     .Input("l1_regularization: float")
108     .Input("l2_regularization: float")
109     .Input("tree_complexity_regularization: float")
110     .Input("min_node_weight: float")
111     .Input("multiclass_strategy: int32")
112     .Output("output_partition_ids: int32")
113     .Output("gains: float32")
114     .Output("split_infos: string")
__anon295dc3ce0202(InferenceContext* c) 115     .SetShapeFn([](InferenceContext* c) {
116       DimensionHandle unused_dim;
117       ShapeHandle unused_shape;
118       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape));
119 
120       ShapeHandle partition_ids_shape;
121       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape));
122       ShapeHandle bucket_ids_shape;
123       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape));
124       ShapeHandle gradients_shape;
125       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape));
126       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
127                                   c->Dim(gradients_shape, 0), &unused_dim));
128       ShapeHandle hessians_shape;
129       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape));
130       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
131                                   c->Dim(hessians_shape, 0), &unused_dim));
132       ShapeHandle bucket_boundaries_shape;
133       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &bucket_boundaries_shape));
134       c->set_output(0, c->Vector(c->UnknownDim()));
135       c->set_output(1, c->Vector(c->UnknownDim()));
136       c->set_output(2, c->Vector(c->UnknownDim()));
137       return Status::OK();
138     })
139     .Doc(R"doc(
140 Find the split that has the best gain for the accumulated stats for a particular
141 feature column.
142 
143 num_minibatches: A scalar, the number of times per example gradients & hessians
144     were accumulated. The stats are divided by this to get per example stats.
145 partition_ids: A rank 2 tensor of partition IDs for each dimension of feature column.
146 bucket_ids: A rank 2 tensor of buckets IDs and dimensions.
147 gradients: A rank 1 tensor of gradients.
148 hessians: A rank 1 tensor of hessians.
149 bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
150 class_id: A scalar, the class id for which we're building the splits.
151 feature_column_group_id: A scalar, the index of the feature we are spiltting on.
152 l1_regularization: A scalar, which specifies the l1 regularization term.
153 l2_regularization: A scalar, which specifies the l2 regularization term.
154 tree_complexity_regularization: A scalar, which specifies the tree complexity
155     regularization term.
156 min_node_weight: A scalar, minimum sum of example hessian needed in a child.
157     If a split results in a leaf node with a smaller value, the split will not
158     be considered.
159 multiclass_strategy: A scalar, specifying the multiclass handling strategy.
160     See LearnerConfig.MultiClassStrategy for valid values.
161 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
162     for.
163 gains: A rank 1 tensor, for the computed gain for the created splits.
164 split_infos: A rank 1 tensor of serialized protos which contains the
165     `SplitInfo`s.
166 )doc");
167 
168 REGISTER_OP("BuildCategoricalEqualitySplits")
169     .Input("num_minibatches: int64")
170     .Input("partition_ids: int32")
171     .Input("feature_ids: int64")
172     .Input("gradients: float32")
173     .Input("hessians: float32")
174     .Input("class_id: int32")
175     .Input("feature_column_group_id: int32")
176     .Input("bias_feature_id: int64")
177     .Input("l1_regularization: float")
178     .Input("l2_regularization: float")
179     .Input("tree_complexity_regularization: float")
180     .Input("min_node_weight: float")
181     .Input("multiclass_strategy: int32")
182     .Input("weak_learner_type: int32")
183     .Output("output_partition_ids: int32")
184     .Output("gains: float32")
185     .Output("split_infos: string")
__anon295dc3ce0302(InferenceContext* c) 186     .SetShapeFn([](InferenceContext* c) {
187       DimensionHandle unused_dim;
188       ShapeHandle unused_shape;
189       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape));
190 
191       ShapeHandle partition_ids_shape;
192       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape));
193       ShapeHandle bucket_ids_shape;
194       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape));
195       ShapeHandle gradients_shape;
196       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape));
197       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
198                                   c->Dim(gradients_shape, 0), &unused_dim));
199       ShapeHandle hessians_shape;
200       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape));
201       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
202                                   c->Dim(hessians_shape, 0), &unused_dim));
203       c->set_output(0, c->Vector(c->UnknownDim()));
204       c->set_output(1, c->Vector(c->UnknownDim()));
205       c->set_output(2, c->Vector(c->UnknownDim()));
206       return Status::OK();
207     })
208     .Doc(R"doc(
209 Find the split that has the best gain for the accumulated stats.
210 
211 num_minibatches: A scalar, the number of times per example gradients & hessians
212     were accumulated. The stats are divided by this to get per example stats.
213 partition_ids: A rank 1 tensor of partition IDs.
214 feature_ids: A rank 2 tensor of feature IDs and dimensions.
215 gradients: A rank 1 tensor of gradients.
216 hessians: A rank 1 tensor of hessians.
217 class_id: A scalar, the class id for which we're building the splits.
218 feature_column_group_id: A scalar, the index of the feature we are spiltting on.
219 l1_regularization: A scalar, which specifies the l1 regularization term.
220 l2_regularization: A scalar, which specifies the l2 regularization term.
221 tree_complexity_regularization: A scalar, which specifies the tree complexity
222     regularization term.
223 min_node_weight: A scalar, minimum sum of example hessian needed in a child.
224     If a split results in a leaf node with a smaller value, the split will not
225     be considered.
226 multiclass_strategy: A scalar, specifying the multiclass handling strategy.
227     See LearnerConfig.MultiClassStrategy for valid values.
228 weak_learner_type: A scalar, specifying the weak learner type to use.
229     See LearnerConfig.WeakLearnerType for valid values.
230 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
231     for.
232 gains: A rank 1 tensor, for the computed gain for the created splits.
233 split_infos: A rank 1 tensor of serialized protos which contains the
234     `SplitInfo`s.
235 )doc");
236 
237 }  // namespace tensorflow
238