1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <limits>
17 #include <vector>
18
19 #include "third_party/eigen3/Eigen/Core"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
24 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace tensorflow {
28
29 using Matrix =
30 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
31 using ConstMatrixMap = Eigen::Map<const Matrix>;
32 using MatrixMap = Eigen::Map<Matrix>;
33
34 using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
35 using VectorMap = Eigen::Map<Eigen::VectorXf>;
36
37 constexpr char kInequalitySplit[] = "inequality";
38 constexpr char kEqualitySplit[] = "equality";
39
40 // V1 Op. Deprecated. Use BoostedTreesCalculateBestFeatureSplitV2 instead.
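// Shape notes (as inferred from the reads in Compute below): node_id_range is
// a length-2 vector [first_node_id, last_node_id), and each tensor in
// stats_summary_list is expected to be [max_splits, num_buckets, 2], the last
// dimension holding one gradient and one hessian value per bucket.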
41 class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
42 public:
43 explicit BoostedTreesCalculateBestGainsPerFeatureOp(
44 OpKernelConstruction* const context)
45 : OpKernel(context) {
46 OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
47 OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
48 }
49
50 void Compute(OpKernelContext* const context) override {
51 // node_id_range
52 const Tensor* node_id_range_t;
53 OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
54 const auto node_id_range = node_id_range_t->vec<int32>();
55 const int32 node_id_first = node_id_range(0); // inclusive
56 const int32 node_id_last = node_id_range(1); // exclusive
57 // stats_summary_list
58 OpInputList stats_summary_list;
59 OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
60 &stats_summary_list));
61 const int64 num_buckets = stats_summary_list[0].dim_size(1);
62 // Check for single logit: 1 gradient + 1 hessian value.
63 DCHECK_EQ(stats_summary_list[0].dim_size(2), 2);
64 std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
65 stats_summary.reserve(stats_summary_list.size());
66 for (const auto& tensor : stats_summary_list) {
67 stats_summary.emplace_back(tensor.tensor<float, 3>());
68 }
69 const Tensor* l1_t;
70 OP_REQUIRES_OK(context, context->input("l1", &l1_t));
71 const auto l1 = l1_t->scalar<float>()();
72 const Tensor* l2_t;
73 OP_REQUIRES_OK(context, context->input("l2", &l2_t));
74 const auto l2 = l2_t->scalar<float>()();
75 const Tensor* tree_complexity_t;
76 OP_REQUIRES_OK(context,
77 context->input("tree_complexity", &tree_complexity_t));
78 const auto tree_complexity = tree_complexity_t->scalar<float>()();
79 const Tensor* min_node_weight_t;
80 OP_REQUIRES_OK(context,
81 context->input("min_node_weight", &min_node_weight_t));
82 const auto min_node_weight = min_node_weight_t->scalar<float>()();
83
84 // Allocate output lists of tensors:
85 OpOutputList output_node_ids_list;
86 OP_REQUIRES_OK(
87 context, context->output_list("node_ids_list", &output_node_ids_list));
88 OpOutputList output_gains_list;
89 OP_REQUIRES_OK(context,
90 context->output_list("gains_list", &output_gains_list));
91 OpOutputList output_thresholds_list;
92 OP_REQUIRES_OK(context, context->output_list("thresholds_list",
93 &output_thresholds_list));
94 OpOutputList output_left_node_contribs_list;
95 OP_REQUIRES_OK(context,
96 context->output_list("left_node_contribs_list",
97 &output_left_node_contribs_list));
98 OpOutputList output_right_node_contribs_list;
99 OP_REQUIRES_OK(context,
100 context->output_list("right_node_contribs_list",
101 &output_right_node_contribs_list));
102
103 // Use identity later to convert float to Eigen::Matrix type for input to
104 // CalculateWeightsAndGains. This op only supports single dimension logits.
105 Eigen::MatrixXf identity;
106 identity.setIdentity(1, 1);
107 // Get the best split info per node for each feature.
108 for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
109 std::vector<float> cum_grad;
110 std::vector<float> cum_hess;
111 cum_grad.reserve(num_buckets);
112 cum_hess.reserve(num_buckets);
113
114 std::vector<int32> output_node_ids;
115 std::vector<float> output_gains;
116 std::vector<int32> output_thresholds;
117 std::vector<float> output_left_node_contribs;
118 std::vector<float> output_right_node_contribs;
119 for (int node_id = node_id_first; node_id < node_id_last; ++node_id) {
120 // Calculate gains.
121 cum_grad.clear();
122 cum_hess.clear();
123 float total_grad = 0.0;
124 float total_hess = 0.0;
125 for (int bucket = 0; bucket < num_buckets; ++bucket) {
126 // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
127 total_grad += stats_summary[feature_idx](node_id, bucket, 0);
128 total_hess += stats_summary[feature_idx](node_id, bucket, 1);
129 cum_grad.push_back(total_grad);
130 cum_hess.push_back(total_hess);
131 }
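// At this point cum_grad[b] / cum_hess[b] hold the totals of buckets 0..b, so
// in the split loop below a candidate split at bucket b sends buckets <= b to
// the left child and recovers the right child's stats as total - cumulative.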
132 // Check if the node has enough average hessian.
133 if (total_hess < min_node_weight) {
134 // Do not split the node: not enough average hessian.
135 continue;
136 }
137 float best_gain = std::numeric_limits<float>::lowest();
138 float best_bucket = 0;
139 float best_contrib_for_left = 0.0;
140 float best_contrib_for_right = 0.0;
141 // Parent gain.
142 float parent_gain;
143 Eigen::VectorXf unused(1);
144 CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
145 l1, l2, &unused, &parent_gain);
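// Rough sketch of the math (a simplification, not the exact
// CalculateWeightsAndGains implementation): for a single logit with only L2
// regularization the helper reduces to approximately
//   w* = -grad / (hess + l2),  gain = grad^2 / (hess + l2),
// with extra thresholding when l1 > 0. A split is worthwhile only if the
// children's combined gain exceeds this parent_gain (subtracted below).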
146
147 for (int bucket = 0; bucket < num_buckets; ++bucket) {
148 const float cum_grad_bucket = cum_grad[bucket];
149 const float cum_hess_bucket = cum_hess[bucket];
150 // Left child.
151 Eigen::VectorXf contrib_for_left(1);
152 float gain_for_left;
153 CalculateWeightsAndGains(cum_grad_bucket * identity,
154 cum_hess_bucket * identity, l1, l2,
155 &contrib_for_left, &gain_for_left);
156 // Right child.
157 // use contrib_for_right.
158 Eigen::VectorXf contrib_for_right(1);
159 float gain_for_right;
160 CalculateWeightsAndGains((total_grad - cum_grad_bucket) * identity,
161 (total_hess - cum_hess_bucket) * identity,
162 l1, l2, &contrib_for_right, &gain_for_right);
163
164 if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
165 best_gain = gain_for_left + gain_for_right;
166 best_bucket = bucket;
167 best_contrib_for_left = contrib_for_left[0];
168 best_contrib_for_right = contrib_for_right[0];
169 }
170 } // for bucket
171 output_node_ids.push_back(node_id);
172 // Remove the parent gain for the parent node.
173 output_gains.push_back(best_gain - parent_gain);
174 output_thresholds.push_back(best_bucket);
175 output_left_node_contribs.push_back(best_contrib_for_left);
176 output_right_node_contribs.push_back(best_contrib_for_right);
177 } // for node_id
178 const int num_nodes = output_node_ids.size();
179 // output_node_ids
180 Tensor* output_node_ids_t;
181 OP_REQUIRES_OK(context,
182 output_node_ids_list.allocate(feature_idx, {num_nodes},
183 &output_node_ids_t));
184 auto output_node_ids_vec = output_node_ids_t->vec<int32>();
185 // output_gains
186 Tensor* output_gains_t;
187 OP_REQUIRES_OK(context, output_gains_list.allocate(
188 feature_idx, {num_nodes}, &output_gains_t));
189 auto output_gains_vec = output_gains_t->vec<float>();
190 // output_thresholds
191 Tensor* output_thresholds_t;
192 OP_REQUIRES_OK(context,
193 output_thresholds_list.allocate(feature_idx, {num_nodes},
194 &output_thresholds_t));
195 auto output_thresholds_vec = output_thresholds_t->vec<int32>();
196 // output_left_node_contribs
197 Tensor* output_left_node_contribs_t;
198 OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
199 feature_idx, {num_nodes, 1},
200 &output_left_node_contribs_t));
201 auto output_left_node_contribs_matrix =
202 output_left_node_contribs_t->matrix<float>();
203 // output_right_node_contribs
204 Tensor* output_right_node_contribs_t;
205 OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
206 feature_idx, {num_nodes, 1},
207 &output_right_node_contribs_t));
208 auto output_right_node_contribs_matrix =
209 output_right_node_contribs_t->matrix<float>();
210 // Sets output tensors from vectors.
211 for (int i = 0; i < num_nodes; ++i) {
212 output_node_ids_vec(i) = output_node_ids[i];
213 // Adjust the gains to penalize by tree complexity.
214 output_gains_vec(i) = output_gains[i] - tree_complexity;
215 output_thresholds_vec(i) = output_thresholds[i];
216 output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
217 // This op only supports 1-dimensional logits.
218 output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
219 }
220 } // for f
221 }
222
223 private:
224 int max_splits_;
225 int num_features_;
226 };
227
228 // V1 op that only supports a single-dimensional logit.
229 REGISTER_KERNEL_BUILDER(
230 Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
231 BoostedTreesCalculateBestGainsPerFeatureOp);
232
233 // Deprecated op. Use BoostedTreesCalculateBestFeatureSplitV2.
234 class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
235 public:
236 explicit BoostedTreesCalculateBestFeatureSplitOp(
237 OpKernelConstruction* const context)
238 : OpKernel(context) {
239 OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
240 OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
241 }
242
243 void Compute(OpKernelContext* const context) override {
244 // node_id_range
245 const Tensor* node_id_range_t;
246 OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
247 const auto node_id_range = node_id_range_t->vec<int32>();
248 const int32 node_id_first = node_id_range(0); // inclusive
249 const int32 node_id_last = node_id_range(1); // exclusive
250
251 const Tensor* stats_summary_t;
252 OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
253 TTypes<float, 4>::ConstTensor stats_summary =
254 stats_summary_t->tensor<float, 4>();
255 const int32 feature_dims = stats_summary_t->dim_size(1);
256 // The last bucket is for default/missing value.
257 const int32 num_buckets = stats_summary_t->dim_size(2) - 1;
258 const int32 logits_dim = logits_dim_;
259 const int32 hessian_dim = stats_summary_t->dim_size(3) - logits_dim;
260 DCHECK_GT(hessian_dim, 0);
261 DCHECK_LE(hessian_dim, logits_dim * logits_dim);
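// Layout notes (inferred from the reads above and below): stats_summary is
// [max_splits, feature_dims, num_buckets + 1, logits_dim + hessian_dim], where
// the extra bucket at index num_buckets carries stats for missing values, and
// hessian_dim is typically logits_dim (diagonal hessian) or
// logits_dim * logits_dim (full hessian), the range the DCHECKs allow.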
262
263 const Tensor* l1_t;
264 OP_REQUIRES_OK(context, context->input("l1", &l1_t));
265 const auto l1 = l1_t->scalar<float>()();
266 DCHECK_GE(l1, 0);
267 if (logits_dim_ > 1) {
268 // Multi-class L1 regularization not supported yet.
269 DCHECK_EQ(l1, 0);
270 }
271
272 const Tensor* l2_t;
273 OP_REQUIRES_OK(context, context->input("l2", &l2_t));
274 const auto l2 = l2_t->scalar<float>()();
275 DCHECK_GE(l2, 0);
276
277 const Tensor* tree_complexity_t;
278 OP_REQUIRES_OK(context,
279 context->input("tree_complexity", &tree_complexity_t));
280 const auto tree_complexity = tree_complexity_t->scalar<float>()();
281
282 const Tensor* min_node_weight_t;
283 OP_REQUIRES_OK(context,
284 context->input("min_node_weight", &min_node_weight_t));
285 const auto min_node_weight = min_node_weight_t->scalar<float>()();
286
287 std::vector<int32> output_node_ids;
288 std::vector<float> output_gains;
289 std::vector<int32> output_feature_dimensions;
290 std::vector<int32> output_thresholds;
291 std::vector<Eigen::VectorXf> output_left_node_contribs;
292 std::vector<Eigen::VectorXf> output_right_node_contribs;
293 std::vector<string> output_split_types;
294
295 // TODO(tanzheny) parallelize the computation.
296 // Iterate each node and find the best gain per node.
297 for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) {
298 float best_gain = std::numeric_limits<float>::lowest();
299 int32 best_bucket = 0;
300 int32 best_f_dim = 0;
301 string best_split_type;
302 Eigen::VectorXf best_contrib_for_left(logits_dim);
303 Eigen::VectorXf best_contrib_for_right(logits_dim);
304 float parent_gain;
305
306 // Including default bucket.
307 ConstMatrixMap stats_mat(&stats_summary(node_id, 0, 0, 0),
308 num_buckets + 1, logits_dim + hessian_dim);
309 const Eigen::VectorXf total_grad =
310 stats_mat.leftCols(logits_dim).colwise().sum();
311 const Eigen::VectorXf total_hess =
312 stats_mat.rightCols(hessian_dim).colwise().sum();
313 if (total_hess.norm() < min_node_weight) {
314 continue;
315 }
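// For multi-dimensional hessians the min_node_weight check above is applied to
// the norm of the summed hessian vector; nodes below the threshold are left
// unsplit.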
316 Eigen::VectorXf parent_weight(logits_dim);
317 CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &parent_weight,
318 &parent_gain);
319
320 if (split_type_ == "inequality") {
321 CalculateBestInequalitySplit(
322 stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
323 num_buckets, min_node_weight, l1, l2, &best_gain, &best_bucket,
324 &best_f_dim, &best_split_type, &best_contrib_for_left,
325 &best_contrib_for_right);
326 } else {
327 CalculateBestEqualitySplit(
328 stats_summary, total_grad, total_hess, node_id, feature_dims,
329 logits_dim, hessian_dim, num_buckets, l1, l2, &best_gain,
330 &best_bucket, &best_f_dim, &best_split_type, &best_contrib_for_left,
331 &best_contrib_for_right);
332 }
333
334 if (best_gain == std::numeric_limits<float>::lowest()) {
335 // Do not add the node if no split is found.
336 continue;
337 }
338 output_node_ids.push_back(node_id);
339 // Remove the parent gain for the parent node.
340 output_gains.push_back(best_gain - parent_gain);
341 output_feature_dimensions.push_back(best_f_dim);
342 // default direction is fixed for dense splits.
343 // TODO(tanzheny) account for default values.
344 output_split_types.push_back(best_split_type);
345 output_thresholds.push_back(best_bucket);
346 output_left_node_contribs.push_back(best_contrib_for_left);
347 output_right_node_contribs.push_back(best_contrib_for_right);
348 } // for node id
349 const int num_nodes = output_node_ids.size();
350 // output_node_ids
351 Tensor* output_node_ids_t = nullptr;
352 OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
353 &output_node_ids_t));
354 auto output_node_ids_vec = output_node_ids_t->vec<int32>();
355
356 // output_gains
357 Tensor* output_gains_t;
358 OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
359 &output_gains_t));
360 auto output_gains_vec = output_gains_t->vec<float>();
361
362 // output_feature_dimensions
363 Tensor* output_feature_dimension_t;
364 OP_REQUIRES_OK(context,
365 context->allocate_output("feature_dimensions", {num_nodes},
366 &output_feature_dimension_t));
367 auto output_feature_dimensions_vec =
368 output_feature_dimension_t->vec<int32>();
369
370 // output_thresholds
371 Tensor* output_thresholds_t;
372 OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
373 &output_thresholds_t));
374 auto output_thresholds_vec = output_thresholds_t->vec<int32>();
375
376 // output_left_node_contribs
377 Tensor* output_left_node_contribs_t;
378 OP_REQUIRES_OK(context, context->allocate_output(
379 "left_node_contribs", {num_nodes, logits_dim},
380 &output_left_node_contribs_t));
381 auto output_left_node_contribs_matrix =
382 output_left_node_contribs_t->matrix<float>();
383
384 // output_right_node_contribs
385 Tensor* output_right_node_contribs_t;
386 OP_REQUIRES_OK(context, context->allocate_output(
387 "right_node_contribs", {num_nodes, logits_dim},
388 &output_right_node_contribs_t));
389 auto output_right_node_contribs_matrix =
390 output_right_node_contribs_t->matrix<float>();
391
392 // split type
393 Tensor* output_split_types_t;
394 OP_REQUIRES_OK(
395 context, context->allocate_output("split_with_default_directions",
396 {num_nodes}, &output_split_types_t));
397 auto output_split_types_vec = output_split_types_t->vec<tstring>();
398
399 // Sets output tensors from vectors.
400 for (int i = 0; i < num_nodes; ++i) {
401 output_node_ids_vec(i) = output_node_ids[i];
402 // Adjust the gains to penalize by tree complexity.
403 output_gains_vec(i) = output_gains[i] - tree_complexity;
404 output_feature_dimensions_vec(i) = output_feature_dimensions[i];
405 output_thresholds_vec(i) = output_thresholds[i];
406 for (int j = 0; j < logits_dim; ++j) {
407 output_left_node_contribs_matrix(i, j) =
408 output_left_node_contribs[i][j];
409 output_right_node_contribs_matrix(i, j) =
410 output_right_node_contribs[i][j];
411 }
412 output_split_types_vec(i) = output_split_types[i];
413 }
414 }
415
416 private:
417 // TODO(crawles): Simplify inequality path just like equality b/138329196.
418 // Currently this is not simplifiable due to numerical instability in the math,
419 // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g),
420 // which can make the gain Inf when g approaches 0 (but is not exactly 0) and
421 // there is no regularization.
422 // Calculate the best inequality split per node.
423 void CalculateBestInequalitySplit(
424 TTypes<float, 4>::ConstTensor stats_summary, const int32 node_id,
425 const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim,
426 const int32 num_buckets, const float min_node_weight, const float l1,
427 const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
428 string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
429 Eigen::VectorXf* best_contrib_for_right) {
430 std::vector<Eigen::VectorXf> cum_grad;
431 std::vector<Eigen::VectorXf> cum_hess;
432 // Cumulative stats over the non-default buckets; the default bucket is handled separately.
433 cum_grad.reserve(num_buckets);
434 cum_hess.reserve(num_buckets);
435
436 for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
437 ConstVectorMap default_stats_vec(
438 &stats_summary(node_id, f_dim, num_buckets, 0),
439 logits_dim + hessian_dim);
440 Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
441 Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
442 cum_grad.clear();
443 cum_hess.clear();
444 Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
445 Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
446 // sum all the gradients including default bucket.
447 for (int bucket = 0; bucket <= num_buckets; ++bucket) {
448 for (int i = 0; i < logits_dim; ++i) {
449 total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
450 }
451 for (int i = 0; i < hessian_dim; ++i) {
452 // Full hessian.
453 total_hess[i] +=
454 stats_summary(node_id, f_dim, bucket, logits_dim + i);
455 }
456 if (bucket < num_buckets) {
457 cum_grad.push_back(total_grad);
458 cum_hess.push_back(total_hess);
459 }
460 }
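// total_grad / total_hess now include every bucket (the default bucket too),
// while cum_grad / cum_hess only cover the non-default buckets; the
// missing-value stats are kept separate so they can be routed to either child
// below.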
461 const string kInequalityDefaultLeft =
462 boosted_trees::SplitTypeWithDefault_Name(
463 boosted_trees::INEQUALITY_DEFAULT_LEFT);
464 const string kInequalityDefaultRight =
465 boosted_trees::SplitTypeWithDefault_Name(
466 boosted_trees::INEQUALITY_DEFAULT_RIGHT);
467
468 // Iterate from left to right, excluding default bucket.
469 for (int bucket = 0; bucket < num_buckets; ++bucket) {
470 // default value goes to left node.
471 const Eigen::VectorXf total_left_grad =
472 cum_grad[bucket] + missing_bucket_grad;
473 const Eigen::VectorXf total_left_hess =
474 cum_hess[bucket] + missing_bucket_hess;
475 MaybeUpdateBestSplit(
476 total_left_grad, total_grad - total_left_grad, total_left_hess,
477 total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
478 kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
479 best_split_type, best_contrib_for_left, best_contrib_for_right);
480 // default value goes to right node.
481 MaybeUpdateBestSplit(
482 cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
483 total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
484 kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
485 best_split_type, best_contrib_for_left, best_contrib_for_right);
486 } // for bucket
487 }
488 }
489
490 // Calculate the best equality split per node.
491 void CalculateBestEqualitySplit(
492 TTypes<float, 4>::ConstTensor stats_summary,
493 const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
494 const int32 node_id, const int32 feature_dims, const int32 logits_dim,
495 const int32 hessian_dim, const int32 num_buckets, const float l1,
496 const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
497 string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
498 Eigen::VectorXf* best_contrib_for_right) {
499 const string kEqualityDefaultRight =
500 boosted_trees::SplitTypeWithDefault_Name(
501 boosted_trees::EQUALITY_DEFAULT_RIGHT);
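// For an equality split, each bucket b is evaluated as "feature == b" going to
// the left child with everything else (including missing values) going right,
// which is why only the EQUALITY_DEFAULT_RIGHT variant is produced here.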
502 for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
503 for (int bucket = 0; bucket < num_buckets; ++bucket) {
504 ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
505 logits_dim + hessian_dim);
506 Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
507 Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
508 MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
509 total_hess - curr_hess, logits_dim, bucket, f_dim,
510 l1, l2, kEqualityDefaultRight, best_gain,
511 best_bucket, best_f_dim, best_split_type,
512 best_contrib_for_left, best_contrib_for_right);
513 }
514 }
515 }
516
517 void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
518 const Eigen::VectorXf& grad_for_right,
519 const Eigen::VectorXf& hess_for_left,
520 const Eigen::VectorXf& hess_for_right,
521 const int32 logits_dim, const int32 bucket,
522 const int32 f_dim, const float l1, const float l2,
523 const string split_type, float* best_gain,
524 int32* best_bucket, int32* best_f_dim,
525 string* best_split_type,
526 Eigen::VectorXf* best_contrib_for_left,
527 Eigen::VectorXf* best_contrib_for_right) {
528 // Left child.
529 Eigen::VectorXf contrib_for_left(logits_dim);
530 float gain_for_left;
531 CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
532 &contrib_for_left, &gain_for_left);
533 Eigen::VectorXf contrib_for_right(logits_dim);
534 float gain_for_right;
535 CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
536 &contrib_for_right, &gain_for_right);
537 if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
538 *best_gain = gain_for_left + gain_for_right;
539 *best_bucket = bucket;
540 *best_f_dim = f_dim;
541 *best_contrib_for_left = contrib_for_left;
542 *best_contrib_for_right = contrib_for_right;
543 *best_split_type = split_type;
544 }
545 }
546
547 int logits_dim_;
548 string split_type_;
549 };
550
551 // Deprecated op. Use BoostedTreesCalculateBestFeatureSplitV2.
552 REGISTER_KERNEL_BUILDER(
553 Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
554 BoostedTreesCalculateBestFeatureSplitOp);
555
556 // V2 Op.
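// Unlike the per-feature ops above, this kernel scans all candidate features
// in a single call: it takes a list of stats summaries plus a per-feature
// split type ("inequality" or "equality"), supports multi-dimensional logits,
// and emits the single best split per node together with the winning feature
// id.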
557 class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
558 public:
559 explicit BoostedTreesCalculateBestFeatureSplitV2(
560 OpKernelConstruction* const context)
561 : OpKernel(context) {
562 OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
563 OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
564 }
565
566 void Compute(OpKernelContext* const context) override {
567 // node_id_range
568 const Tensor* node_id_range_t;
569 OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
570 const auto node_id_range = node_id_range_t->vec<int32>();
571 const int32 node_id_first = node_id_range(0); // Inclusive.
572 const int32 node_id_last = node_id_range(1); // Exclusive.
573
574 // Get stats_summaries_list.
575 OpInputList stats_summaries_list;
576 OP_REQUIRES_OK(context, context->input_list("stats_summaries_list",
577 &stats_summaries_list));
578
579 // Infer dimensions of a stats_summary.
580 DCHECK_GT(stats_summaries_list.size(), 0);
581 const int32 feature_dims = stats_summaries_list[0].dim_size(1);
582 // The last bucket is for default/missing value.
583 const int32 num_buckets = stats_summaries_list[0].dim_size(2) - 1;
584 const int32 logits_dim = logits_dim_;
585 const int32 hessian_dim = stats_summaries_list[0].dim_size(3) - logits_dim;
586 DCHECK_GT(hessian_dim, 0);
587 DCHECK_LE(hessian_dim, logits_dim * logits_dim);
588
589 // Vector of stats_summaries; each element is stats for feature of shape
590 // [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim].
591 std::vector<TTypes<float, 4>::ConstTensor> stats_summaries;
592 DCHECK_EQ(stats_summaries_list.size(), num_features_);
593 stats_summaries.reserve(num_features_);
594 for (const auto& tensor : stats_summaries_list) {
595 stats_summaries.emplace_back(tensor.tensor<float, 4>());
596 }
597
598 // Split types.
599 const Tensor* split_types_t;
600 OP_REQUIRES_OK(context, context->input("split_types", &split_types_t));
601 const auto split_types = split_types_t->vec<tstring>();
602 DCHECK_EQ(split_types.size(), num_features_);
603 // Validate.
604 for (int i = 0; i < num_features_; ++i) {
605 if (!(split_types(i) == kInequalitySplit ||
606 split_types(i) == kEqualitySplit)) {
607 OP_REQUIRES_OK(
608 context,
609 errors::Aborted(
610 "Operation received an exception: Incorrect split type"));
611 }
612 }
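// Note: constructing a non-OK status inside OP_REQUIRES_OK above fails the op
// for any split type other than "inequality" or "equality".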
613 // Feature ids.
614 const Tensor* candidate_feature_ids_t;
615 OP_REQUIRES_OK(context, context->input("candidate_feature_ids",
616 &candidate_feature_ids_t));
617 const auto candidate_feature_ids = candidate_feature_ids_t->vec<int32>();
618 DCHECK_EQ(candidate_feature_ids.size(), num_features_);
619
620 // L1, L2, tree_complexity, min_node_weight.
621 const Tensor* l1_t;
622 OP_REQUIRES_OK(context, context->input("l1", &l1_t));
623 const auto l1 = l1_t->scalar<float>()();
624 DCHECK_GE(l1, 0);
625 if (logits_dim_ > 1) {
626 // Multi-class L1 regularization not supported yet.
627 DCHECK_EQ(l1, 0);
628 }
629 const Tensor* l2_t;
630 OP_REQUIRES_OK(context, context->input("l2", &l2_t));
631 const auto l2 = l2_t->scalar<float>()();
632 DCHECK_GE(l2, 0);
633 const Tensor* tree_complexity_t;
634 OP_REQUIRES_OK(context,
635 context->input("tree_complexity", &tree_complexity_t));
636 const auto tree_complexity = tree_complexity_t->scalar<float>()();
637 const Tensor* min_node_weight_t;
638 OP_REQUIRES_OK(context,
639 context->input("min_node_weight", &min_node_weight_t));
640 const auto min_node_weight = min_node_weight_t->scalar<float>()();
641
642 std::vector<int32> output_node_ids;
643 std::vector<float> output_gains;
644 std::vector<int32> output_feature_ids;
645 std::vector<int32> output_feature_dimensions;
646 std::vector<int32> output_thresholds;
647 std::vector<Eigen::VectorXf> output_left_node_contribs;
648 std::vector<Eigen::VectorXf> output_right_node_contribs;
649 std::vector<string> output_split_types;
650
651 // TODO(tanzheny) parallelize the computation.
652 // Iterate each node and find the best gain per node.
653 float parent_gain;
654 for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) {
655 float best_gain = std::numeric_limits<float>::lowest();
656 int32 best_bucket;
657 int32 best_f_id;
658 int32 best_f_dim;
659 string best_split_type;
660 Eigen::VectorXf best_contrib_for_left(logits_dim);
661 Eigen::VectorXf best_contrib_for_right(logits_dim);
662
663 // Sum of gradient and hessian. Compute parent gain using first feature.
664 ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0),
665 num_buckets + 1, // Including default bucket.
666 logits_dim + hessian_dim);
667 const Eigen::VectorXf total_grad =
668 stats_mat.leftCols(logits_dim).colwise().sum();
669 const Eigen::VectorXf total_hess =
670 stats_mat.rightCols(hessian_dim).colwise().sum();
671 if (total_hess.norm() < min_node_weight) {
672 continue;
673 }
674 Eigen::VectorXf unused(logits_dim);
675 CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
676 &parent_gain);
677 for (int f_idx = 0; f_idx < num_features_; ++f_idx) {
678 const string split_type = split_types(f_idx);
679 TTypes<float, 4>::ConstTensor stats_summary = stats_summaries[f_idx];
680 float f_best_gain = std::numeric_limits<float>::lowest();
681 int32 f_best_bucket;
682 int32 f_best_f_dim;
683 string f_best_split_type;
684 Eigen::VectorXf f_best_contrib_for_left(logits_dim);
685 Eigen::VectorXf f_best_contrib_for_right(logits_dim);
686
687 if (split_type == kInequalitySplit) {
688 CalculateBestInequalitySplit(
689 stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
690 num_buckets, min_node_weight, l1, l2, &f_best_gain,
691 &f_best_bucket, &f_best_f_dim, &f_best_split_type,
692 &f_best_contrib_for_left, &f_best_contrib_for_right);
693 } else {
694 CalculateBestEqualitySplit(
695 stats_summary, total_grad, total_hess, node_id, feature_dims,
696 logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain,
697 &f_best_bucket, &f_best_f_dim, &f_best_split_type,
698 &f_best_contrib_for_left, &f_best_contrib_for_right);
699 }
700 if (f_best_gain > best_gain) {
701 best_gain = f_best_gain;
702 best_f_id = candidate_feature_ids(f_idx);
703 best_f_dim = f_best_f_dim;
704 best_split_type = f_best_split_type;
705 best_bucket = f_best_bucket;
706 best_contrib_for_left = f_best_contrib_for_left;
707 best_contrib_for_right = f_best_contrib_for_right;
708 }
709 } // For feature id.
710 if (best_gain == std::numeric_limits<float>::lowest()) {
711 // Do not add the node if no split is found.
712 continue;
713 }
714 output_node_ids.push_back(node_id);
715 // Remove the parent gain for the parent node.
716 output_gains.push_back(best_gain - parent_gain);
717 output_feature_ids.push_back(best_f_id);
718 output_feature_dimensions.push_back(best_f_dim);
719 // Default direction is fixed for dense splits.
720 // TODO(tanzheny) account for default values.
721 output_split_types.push_back(best_split_type);
722 output_thresholds.push_back(best_bucket);
723 output_left_node_contribs.push_back(best_contrib_for_left);
724 output_right_node_contribs.push_back(best_contrib_for_right);
725 } // for node id.
726 const int num_nodes = output_node_ids.size();
727 // output_node_ids
728 Tensor* output_node_ids_t = nullptr;
729 OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
730 &output_node_ids_t));
731 auto output_node_ids_vec = output_node_ids_t->vec<int32>();
732
733 // output_gains
734 Tensor* output_gains_t;
735 OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
736 &output_gains_t));
737 auto output_gains_vec = output_gains_t->vec<float>();
738
739 // output_feature_ids
740 Tensor* output_features_ids_t;
741 OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes},
742 &output_features_ids_t));
743 auto output_features_vec = output_features_ids_t->vec<int32>();
744
745 // output_feature_dimensions
746 Tensor* output_feature_dimension_t;
747 OP_REQUIRES_OK(context,
748 context->allocate_output("feature_dimensions", {num_nodes},
749 &output_feature_dimension_t));
750 auto output_feature_dimensions_vec =
751 output_feature_dimension_t->vec<int32>();
752
753 // output_thresholds
754 Tensor* output_thresholds_t;
755 OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
756 &output_thresholds_t));
757 auto output_thresholds_vec = output_thresholds_t->vec<int32>();
758
759 // output_left_node_contribs
760 Tensor* output_left_node_contribs_t;
761 OP_REQUIRES_OK(context, context->allocate_output(
762 "left_node_contribs", {num_nodes, logits_dim},
763 &output_left_node_contribs_t));
764 auto output_left_node_contribs_matrix =
765 output_left_node_contribs_t->matrix<float>();
766
767 // output_right_node_contribs
768 Tensor* output_right_node_contribs_t;
769 OP_REQUIRES_OK(context, context->allocate_output(
770 "right_node_contribs", {num_nodes, logits_dim},
771 &output_right_node_contribs_t));
772 auto output_right_node_contribs_matrix =
773 output_right_node_contribs_t->matrix<float>();
774
775 // split type
776 Tensor* output_split_types_t;
777 OP_REQUIRES_OK(
778 context, context->allocate_output("split_with_default_directions",
779 {num_nodes}, &output_split_types_t));
780 auto output_split_types_vec = output_split_types_t->vec<tstring>();
781
782 // Sets output tensors from vectors.
783 for (int i = 0; i < num_nodes; ++i) {
784 output_node_ids_vec(i) = output_node_ids[i];
785 output_features_vec(i) = output_feature_ids[i];
786 // Adjust the gains to penalize by tree complexity.
787 output_gains_vec(i) = output_gains[i] - tree_complexity;
788 output_feature_dimensions_vec(i) = output_feature_dimensions[i];
789 output_thresholds_vec(i) = output_thresholds[i];
790 for (int j = 0; j < logits_dim; ++j) {
791 output_left_node_contribs_matrix(i, j) =
792 output_left_node_contribs[i][j];
793 output_right_node_contribs_matrix(i, j) =
794 output_right_node_contribs[i][j];
795 }
796 output_split_types_vec(i) = output_split_types[i];
797 }
798 }
799
800 private:
801 // TODO(crawles): Simplify inequality path just like equality b/138329196.
802 // Currently this is not simplifiable due to numerical instability in the math,
803 // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g),
804 // which can make the gain Inf when g approaches 0 (but is not exactly 0) and
805 // there is no regularization.
806 // Calculate the best inequality split per node.
807 void CalculateBestInequalitySplit(
808 TTypes<float, 4>::ConstTensor stats_summary, const int32 node_id,
809 const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim,
810 const int32 num_buckets, const float min_node_weight, const float l1,
811 const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
812 string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
813 Eigen::VectorXf* best_contrib_for_right) {
814 std::vector<Eigen::VectorXf> cum_grad;
815 std::vector<Eigen::VectorXf> cum_hess;
816 // Cumulative stats over the non-default buckets; the default bucket is handled separately.
817 cum_grad.reserve(num_buckets);
818 cum_hess.reserve(num_buckets);
819
820 for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
821 ConstVectorMap default_stats_vec(
822 &stats_summary(node_id, f_dim, num_buckets, 0),
823 logits_dim + hessian_dim);
824 Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
825 Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
826 cum_grad.clear();
827 cum_hess.clear();
828 Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
829 Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
830 // sum all the gradients including default bucket.
831 for (int bucket = 0; bucket <= num_buckets; ++bucket) {
832 for (int i = 0; i < logits_dim; ++i) {
833 total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
834 }
835 for (int i = 0; i < hessian_dim; ++i) {
836 // Full hessian.
837 total_hess[i] +=
838 stats_summary(node_id, f_dim, bucket, logits_dim + i);
839 }
840 if (bucket < num_buckets) {
841 cum_grad.push_back(total_grad);
842 cum_hess.push_back(total_hess);
843 }
844 }
845 const string kInequalityDefaultLeft =
846 boosted_trees::SplitTypeWithDefault_Name(
847 boosted_trees::INEQUALITY_DEFAULT_LEFT);
848 const string kInequalityDefaultRight =
849 boosted_trees::SplitTypeWithDefault_Name(
850 boosted_trees::INEQUALITY_DEFAULT_RIGHT);
851
852 // Iterate from left to right, excluding default bucket.
853 for (int bucket = 0; bucket < num_buckets; ++bucket) {
854 // default value goes to left node.
855 const Eigen::VectorXf total_left_grad =
856 cum_grad[bucket] + missing_bucket_grad;
857 const Eigen::VectorXf total_left_hess =
858 cum_hess[bucket] + missing_bucket_hess;
859 MaybeUpdateBestSplit(
860 total_left_grad, total_grad - total_left_grad, total_left_hess,
861 total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
862 kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
863 best_split_type, best_contrib_for_left, best_contrib_for_right);
864 // default value goes to right node.
865 MaybeUpdateBestSplit(
866 cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
867 total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
868 kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
869 best_split_type, best_contrib_for_left, best_contrib_for_right);
870 } // for bucket
871 }
872 }
873
874 // Calculate the best equality split per node.
875 void CalculateBestEqualitySplit(
876 TTypes<float, 4>::ConstTensor stats_summary,
877 const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
878 const int32 node_id, const int32 feature_dims, const int32 logits_dim,
879 const int32 hessian_dim, const int32 num_buckets, const float l1,
880 const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
881 string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
882 Eigen::VectorXf* best_contrib_for_right) {
883 const string kEqualityDefaultRight =
884 boosted_trees::SplitTypeWithDefault_Name(
885 boosted_trees::EQUALITY_DEFAULT_RIGHT);
886 for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
887 for (int bucket = 0; bucket < num_buckets; ++bucket) {
888 ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
889 logits_dim + hessian_dim);
890 Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
891 Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
892 MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
893 total_hess - curr_hess, logits_dim, bucket, f_dim,
894 l1, l2, kEqualityDefaultRight, best_gain,
895 best_bucket, best_f_dim, best_split_type,
896 best_contrib_for_left, best_contrib_for_right);
897 }
898 }
899 }
900
901 void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
902 const Eigen::VectorXf& grad_for_right,
903 const Eigen::VectorXf& hess_for_left,
904 const Eigen::VectorXf& hess_for_right,
905 const int32 logits_dim, const int32 bucket,
906 const int32 f_dim, const float l1, const float l2,
907 const string split_type, float* best_gain,
908 int32* best_bucket, int32* best_f_dim,
909 string* best_split_type,
910 Eigen::VectorXf* best_contrib_for_left,
911 Eigen::VectorXf* best_contrib_for_right) {
912 // Left child.
913 Eigen::VectorXf contrib_for_left(logits_dim);
914 float gain_for_left;
915 CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
916 &contrib_for_left, &gain_for_left);
917 Eigen::VectorXf contrib_for_right(logits_dim);
918 float gain_for_right;
919 CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
920 &contrib_for_right, &gain_for_right);
921 if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
922 *best_gain = gain_for_left + gain_for_right;
923 *best_bucket = bucket;
924 *best_f_dim = f_dim;
925 *best_contrib_for_left = contrib_for_left;
926 *best_contrib_for_right = contrib_for_right;
927 *best_split_type = split_type;
928 }
929 }
930 int num_features_;
931 int logits_dim_;
932 };
933
934 // V2 op that supports multi-class.
935 REGISTER_KERNEL_BUILDER(
936 Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU),
937 BoostedTreesCalculateBestFeatureSplitV2);
938
939 // Map from bucket id to vector of statistics.
940 typedef std::map<int32, std::vector<float>> BucketMap;
941 typedef BucketMap::iterator BucketMapIterator;
942 // Map from feature dimension to BucketMap.
943 typedef std::map<int32, BucketMap> FeatureMap;
944 typedef FeatureMap::iterator FeatureMapIterator;
945
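// The sparse kernel below rebuilds, one node at a time, a dense-like view of
// the stats: FeatureMap maps feature_dim -> (bucket_id -> stats vector), and is
// populated from the COO-style (node, feature_dim, bucket, stat_dim) indices.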
946 class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
947 public:
948 explicit BoostedTreesSparseCalculateBestFeatureSplitOp(
949 OpKernelConstruction* const context)
950 : OpKernel(context) {
951 // TODO(crawles): Use logits_dim_ for multi-class split.
952 OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
953 // TODO(tanzheny): Use this for equality split.
954 OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
955 }
956
957 void Compute(OpKernelContext* const context) override {
958 // node_id_range
959 const Tensor* node_id_range_t;
960 OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
961 const auto node_id_range = node_id_range_t->vec<int32>();
962 const int32 node_id_first = node_id_range(0); // inclusive
963 const int32 node_id_last = node_id_range(1); // exclusive
964
965 const Tensor* stats_summary_indices_t;
966 OP_REQUIRES_OK(context, context->input("stats_summary_indices",
967 &stats_summary_indices_t));
968 const auto stats_summary_indices = stats_summary_indices_t->matrix<int32>();
969 const int32 num_sparse_entries = stats_summary_indices_t->dim_size(0);
970
971 const Tensor* stats_summary_values_t;
972 OP_REQUIRES_OK(context, context->input("stats_summary_values",
973 &stats_summary_values_t));
974 const auto stats_summary_values = stats_summary_values_t->vec<float>();
975
976 const Tensor* stats_summary_shape_t;
977 OP_REQUIRES_OK(
978 context, context->input("stats_summary_shape", &stats_summary_shape_t));
979 const auto stats_summary_shape = stats_summary_shape_t->vec<int32>();
980 const int32 num_buckets = stats_summary_shape(2) - 1;
981 const int32 stats_dims = stats_summary_shape(3);
982
983 const Tensor* l1_t;
984 OP_REQUIRES_OK(context, context->input("l1", &l1_t));
985 const auto l1 = l1_t->scalar<float>()();
986
987 const Tensor* l2_t;
988 OP_REQUIRES_OK(context, context->input("l2", &l2_t));
989 const auto l2 = l2_t->scalar<float>()();
990
991 const Tensor* tree_complexity_t;
992 OP_REQUIRES_OK(context,
993 context->input("tree_complexity", &tree_complexity_t));
994 const auto tree_complexity = tree_complexity_t->scalar<float>()();
995
996 const Tensor* min_node_weight_t;
997 OP_REQUIRES_OK(context,
998 context->input("min_node_weight", &min_node_weight_t));
999 const auto min_node_weight = min_node_weight_t->scalar<float>()();
1000
1001 std::vector<int32> output_node_ids;
1002 std::vector<float> output_gains;
1003 std::vector<int32> output_feature_dimensions;
1004 std::vector<int32> output_thresholds;
1005 std::vector<float> output_left_node_contribs;
1006 std::vector<float> output_right_node_contribs;
1007 std::vector<string> output_split_types;
1008
1009 FeatureMap f_map;
1010
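// Entries are grouped by streaming through them and flushing the accumulated
// FeatureMap whenever the node id changes; this appears to assume that the
// sparse indices are ordered by node id (their first column).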
1011 int32 previous_node_id = -1;
1012 for (int idx = 0; idx < num_sparse_entries; ++idx) {
1013 int32 node_id = stats_summary_indices(idx, 0);
1014 if (node_id != previous_node_id) {
1015 process_node(f_map, &output_node_ids, &output_gains,
1016 &output_feature_dimensions, &output_thresholds,
1017 &output_left_node_contribs, &output_right_node_contribs,
1018 &output_split_types, previous_node_id, min_node_weight, l1,
1019 l2, num_buckets);
1020 f_map.clear();
1021 }
1022 previous_node_id = node_id;
1023 DCHECK_LE(node_id_first, node_id);
1024 DCHECK_LT(node_id, node_id_last);
1025 const int32 feature_dim = stats_summary_indices(idx, 1);
1026 const int32 bucket_id = stats_summary_indices(idx, 2);
1027 const int32 stat_dim = stats_summary_indices(idx, 3);
1028 std::pair<FeatureMapIterator, bool> const& f_insert_result = f_map.insert(
1029 FeatureMapIterator::value_type(feature_dim, BucketMap()));
1030 auto& b_map = f_insert_result.first->second;
1031 std::pair<BucketMapIterator, bool> const& b_insert_result =
1032 b_map.insert(BucketMapIterator::value_type(
1033 bucket_id, std::vector<float>(stats_dims)));
1034 auto& stats = b_insert_result.first->second;
1035 stats[stat_dim] = stats_summary_values(idx);
1036 } // for node_id
1037 // process the last node id
1038 process_node(f_map, &output_node_ids, &output_gains,
1039 &output_feature_dimensions, &output_thresholds,
1040 &output_left_node_contribs, &output_right_node_contribs,
1041 &output_split_types, previous_node_id, min_node_weight, l1, l2,
1042 num_buckets);
1043
1044 const int num_nodes = output_node_ids.size();
1045 // output_node_ids
1046 Tensor* output_node_ids_t = nullptr;
1047 OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
1048 &output_node_ids_t));
1049 auto output_node_ids_vec = output_node_ids_t->vec<int32>();
1050
1051 // output_gains
1052 Tensor* output_gains_t;
1053 OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
1054 &output_gains_t));
1055 auto output_gains_vec = output_gains_t->vec<float>();
1056
1057 // output_feature_dimensions
1058 Tensor* output_feature_dimension_t;
1059 OP_REQUIRES_OK(context,
1060 context->allocate_output("feature_dimensions", {num_nodes},
1061 &output_feature_dimension_t));
1062 auto output_feature_dimensions_vec =
1063 output_feature_dimension_t->vec<int32>();
1064
1065 // output_thresholds
1066 Tensor* output_thresholds_t;
1067 OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
1068 &output_thresholds_t));
1069 auto output_thresholds_vec = output_thresholds_t->vec<int32>();
1070
1071 // output_left_node_contribs
1072 Tensor* output_left_node_contribs_t;
1073 OP_REQUIRES_OK(
1074 context, context->allocate_output("left_node_contribs", {num_nodes, 1},
1075 &output_left_node_contribs_t));
1076 auto output_left_node_contribs_matrix =
1077 output_left_node_contribs_t->matrix<float>();
1078
1079 // output_right_node_contribs
1080 Tensor* output_right_node_contribs_t;
1081 OP_REQUIRES_OK(
1082 context, context->allocate_output("right_node_contribs", {num_nodes, 1},
1083 &output_right_node_contribs_t));
1084 auto output_right_node_contribs_matrix =
1085 output_right_node_contribs_t->matrix<float>();
1086
1087 // split type
1088 Tensor* output_split_types_t;
1089 OP_REQUIRES_OK(
1090 context, context->allocate_output("split_with_default_directions",
1091 {num_nodes}, &output_split_types_t));
1092 auto output_split_types_vec = output_split_types_t->vec<tstring>();
1093
1094 // Sets output tensors from vectors.
1095 for (int i = 0; i < num_nodes; ++i) {
1096 output_node_ids_vec(i) = output_node_ids[i];
1097 // Adjust the gains to penalize by tree complexity.
1098 output_gains_vec(i) = output_gains[i] - tree_complexity;
1099 output_feature_dimensions_vec(i) = output_feature_dimensions[i];
1100 output_thresholds_vec(i) = output_thresholds[i];
1101 // TODO(crawles): change this for multi-class.
1102 output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
1103 output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
1104 output_split_types_vec(i) = output_split_types[i];
1105 }
1106 }
1107
1108 protected:
1109 void process_node(const FeatureMap& f_map,
1110 std::vector<int32>* output_node_ids,
1111 std::vector<float>* output_gains,
1112 std::vector<int32>* output_feature_dimensions,
1113 std::vector<int32>* output_thresholds,
1114 std::vector<float>* output_left_node_contribs,
1115 std::vector<float>* output_right_node_contribs,
1116 std::vector<string>* output_split_types,
1117 const int32 node_id, const float min_node_weight,
1118 const float l1, const float l2, const int32 num_buckets) {
1119 float parent_gain;
1120 Eigen::VectorXf unused(logits_dim_);
1121 Eigen::MatrixXf identity;
1122 identity.setIdentity(1, 1);
1123
1124 // Start processing the previous node id.
1125 float best_gain = std::numeric_limits<float>::lowest();
1126 float best_bucket = 0;
1127 float best_f_dim = 0;
1128 string best_split_type = boosted_trees::SplitTypeWithDefault_Name(
1129 boosted_trees::INEQUALITY_DEFAULT_LEFT);
1130 float best_contrib_for_left = 0.0;
1131 float best_contrib_for_right = 0.0;
1132 // the sum of gradients including default bucket.
1133 float total_grad = 0;
1134 // the sum of hessians including default bucket.
1135 float total_hess = 0;
1136
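// Node totals (and hence parent_gain) are computed from the first feature
// dimension only, on the assumption that every feature dimension sums to the
// same per-node gradient and hessian totals.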
1137 for (auto f_iter = f_map.begin(); f_iter != f_map.end(); ++f_iter) {
1138 const int32 feature_dim = f_iter->first;
1139 const auto buckets_to_stats_map = f_iter->second;
1140
1141 // The very last bucket contains stats for missing values.
1142 // TODO(crawles): use vector for multi-class.
1143 const float default_grad =
1144 (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
1145 ? 0
1146 : buckets_to_stats_map.at(num_buckets)[0]);
1147 const float default_hess =
1148 (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
1149 ? 0
1150 : buckets_to_stats_map.at(num_buckets)[1]);
1151
1152 if (f_iter == f_map.begin()) {
1153 // First get the sum of gradients and hessians, including the default bucket.
1154 for (auto b_iter = buckets_to_stats_map.begin();
1155 b_iter != buckets_to_stats_map.end(); ++b_iter) {
1156 total_grad += b_iter->second[0];
1157 total_hess += b_iter->second[1];
1158 }
1159 if (total_hess < min_node_weight) {
1160 // Do not split the node: not enough average hessian.
1161 break;
1162 }
1163 CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
1164 l1, l2, &unused, &parent_gain);
1165 }
1166
1167 float total_left_grad = 0;
1168 float total_left_hess = 0;
1169 for (auto b_iter = buckets_to_stats_map.begin();
1170 b_iter != buckets_to_stats_map.end(); ++b_iter) {
1171 const int32 bucket_id = b_iter->first;
1172 // total_left_stats should exclude stats from default bucket.
1173 if (bucket_id == num_buckets) {
1174 break;
1175 }
1176 // TODO(crawles): vector for multi-class.
1177 total_left_grad += b_iter->second[0];
1178 total_left_hess += b_iter->second[1];
1179 // From left to right, default right.
1180 // Left child.
1181 Eigen::VectorXf contrib_for_left(1);
1182 float gain_for_left;
1183 CalculateWeightsAndGains(total_left_grad * identity,
1184 total_left_hess * identity, l1, l2,
1185 &contrib_for_left, &gain_for_left);
1186 // Right child.
1187 Eigen::VectorXf contrib_for_right(1);
1188 float gain_for_right;
1189 CalculateWeightsAndGains((total_grad - total_left_grad) * identity,
1190 (total_hess - total_left_hess) * identity, l1,
1191 l2, &contrib_for_right, &gain_for_right);
1192 if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
1193 best_gain = gain_for_left + gain_for_right;
1194 best_bucket = bucket_id;
1195 best_f_dim = feature_dim;
1196 best_split_type = boosted_trees::SplitTypeWithDefault_Name(
1197 boosted_trees::INEQUALITY_DEFAULT_RIGHT);
1198 best_contrib_for_left = contrib_for_left[0];
1199 best_contrib_for_right = contrib_for_right[0];
1200 }
1201
1202 // From right to left, default left.
1203 CalculateWeightsAndGains((total_left_grad + default_grad) * identity,
1204 (total_left_hess + default_hess) * identity,
1205 l1, l2, &contrib_for_left, &gain_for_left);
1206 CalculateWeightsAndGains(
1207 (total_grad - default_grad - total_left_grad) * identity,
1208 (total_hess - default_hess - total_left_hess) * identity, l1, l2,
1209 &contrib_for_right, &gain_for_right);
1210 if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
1211 best_gain = gain_for_left + gain_for_right;
1212 best_bucket = bucket_id;
1213 best_f_dim = feature_dim;
1214 best_split_type = boosted_trees::SplitTypeWithDefault_Name(
1215 boosted_trees::INEQUALITY_DEFAULT_LEFT);
1216 best_contrib_for_left = contrib_for_left[0];
1217 best_contrib_for_right = contrib_for_right[0];
1218 }
1219 } // for bucket_id
1220 } // for feature_dim
1221 if (best_gain != std::numeric_limits<float>::lowest()) {
1222 output_node_ids->push_back(node_id);
1223 // Remove the parent gain.
1224 output_gains->push_back(best_gain - parent_gain);
1225 output_feature_dimensions->push_back(best_f_dim);
1226 output_split_types->push_back(best_split_type);
1227 output_thresholds->push_back(best_bucket);
1228 output_left_node_contribs->push_back(best_contrib_for_left);
1229 output_right_node_contribs->push_back(best_contrib_for_right);
1230 }
1231 }
1232
1233 private:
1234 int logits_dim_;
1235 string split_type_;
1236 };
1237
1238 REGISTER_KERNEL_BUILDER(
1239 Name("BoostedTreesSparseCalculateBestFeatureSplit").Device(DEVICE_CPU),
1240 BoostedTreesSparseCalculateBestFeatureSplitOp);
1241
1242 class BoostedTreesMakeStatsSummaryOp : public OpKernel {
1243 public:
1244 explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
1245 : OpKernel(context) {
1246 OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
1247 OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
1248 OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
1249 }
1250
1251 void Compute(OpKernelContext* const context) override {
1252 // node_ids
1253 const Tensor* node_ids_t;
1254 OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
1255 const auto node_ids = node_ids_t->vec<int32>();
1256 // gradients
1257 const Tensor* gradients_t;
1258 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
1259 const auto gradients = gradients_t->matrix<float>();
1260 // hessians
1261 const Tensor* hessians_t;
1262 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
1263 const auto hessians = hessians_t->matrix<float>();
1264 // bucketized_features
1265 OpInputList bucketized_features_list;
1266 OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
1267 &bucketized_features_list));
1268 // Infer batch size.
1269 const int64 batch_size = node_ids_t->dim_size(0);
1270
1271 // Allocate temporary stats tensor (Rank 4).
1272 Tensor temp_stats_double_t;
1273 OP_REQUIRES_OK(context, context->allocate_temp(
1274 DT_DOUBLE,
1275 {num_features_, max_splits_, num_buckets_, 2},
1276 &temp_stats_double_t));
1277 auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
1278 temp_stats_double.setZero();
1279
1280 // Partition by node, and then bucketize.
1281 for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
1282 const auto& features = bucketized_features_list[feature_idx].vec<int32>();
1283 for (int i = 0; i < batch_size; ++i) {
1284 const int32 node = node_ids(i);
1285 const int32 bucket = features(i);
1286 temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
1287 temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
1288 }
1289 }
1290
1291 // Copy temp tensor over to output tensor.
1292 Tensor* output_stats_summary_t = nullptr;
1293 OP_REQUIRES_OK(context, context->allocate_output(
1294 "stats_summary", temp_stats_double_t.shape(),
1295 &output_stats_summary_t));
1296 output_stats_summary_t->tensor<float, 4>() =
1297 temp_stats_double.template cast<float>();
1298 }
1299
1300 private:
1301 int max_splits_;
1302 int num_buckets_;
1303 int num_features_;
1304 };
1305
1306 REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
1307 BoostedTreesMakeStatsSummaryOp);
1308
1309 // TODO(tanzheny): Add an option of default value into the API interface.
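// Dense stats aggregation. Unlike BoostedTreesMakeStatsSummary, this op
// supports multi-dimensional features and vector-valued (multi-class)
// gradients and hessians, and it reserves one extra bucket at index
// num_buckets for missing values (a bucketized feature value of -1). The
// output has shape [max_splits, feature_dims, num_buckets + 1,
// logits_dims + hessians_dims].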
1310 class BoostedTreesAggregateStatsOp : public OpKernel {
1311 public:
1312   explicit BoostedTreesAggregateStatsOp(OpKernelConstruction* const context)
1313 : OpKernel(context) {
1314 OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
1315 OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
1316 }
1317
1318   void Compute(OpKernelContext* const context) override {
1319 // node_ids.
1320 const Tensor* node_ids_t;
1321 OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
1322 const auto node_ids = node_ids_t->vec<int32>();
1323
1324 // gradients.
1325 const Tensor* gradients_t;
1326 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
1327 const auto gradients = gradients_t->matrix<float>();
1328
1329 // hessians.
1330 const Tensor* hessians_t;
1331 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
1332 const auto hessians = hessians_t->matrix<float>();
1333
1334 // feature.
1335 const Tensor* feature_t;
1336 OP_REQUIRES_OK(context, context->input("feature", &feature_t));
1337 const auto feature = feature_t->matrix<int32>();
1338
1339 // Infer batch size, feature dimension and stats dimension.
1340 const int64 batch_size = node_ids_t->dim_size(0);
1341 const int64 logits_dims = gradients_t->dim_size(1);
1342 const int64 hessians_dims = hessians_t->dim_size(1);
1343 const int64 stats_dims = logits_dims + hessians_dims;
1344 const int64 feature_dims = feature_t->dim_size(1);
1345
1346 // Allocate temporary stats tensor (Rank 4), upcasting to double.
1347 // A default bucket is added to the end for missing/default values.
1348 Tensor temp_stats_double_t;
1349 OP_REQUIRES_OK(
1350 context, context->allocate_temp(
1351 DT_DOUBLE,
1352 {max_splits_, feature_dims, num_buckets_ + 1, stats_dims},
1353 &temp_stats_double_t));
1354 auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
1355 temp_stats_double.setZero();
1356
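    // Illustrative example (hypothetical values): with num_buckets_ = 3, an
    // example routed to node 2 whose bucketized value for feature dimension 0
    // is -1 (missing) adds its gradient and hessian to
    // temp_stats_double(2, 0, 3, ...), i.e. the extra default bucket.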
1357 for (int i = 0; i < batch_size; ++i) {
1358 const int32 node = node_ids(i);
1359 for (int feature_dim = 0; feature_dim < feature_dims; ++feature_dim) {
1360 const int32 feature_value = feature(i, feature_dim);
1361 const int32 bucket =
1362 (feature_value == -1) ? num_buckets_ : feature_value;
1363 for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
1364 temp_stats_double(node, feature_dim, bucket, stat_dim) +=
1365 gradients(i, stat_dim);
1366 }
1367 for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
1368 temp_stats_double(node, feature_dim, bucket, stat_dim) +=
1369 hessians(i, stat_dim - logits_dims);
1370 }
1371 }
1372 }
1373
1374 // Copy temp tensor over to output tensor, downcasting to float.
1375 Tensor* output_stats_summary_t = nullptr;
1376 OP_REQUIRES_OK(context, context->allocate_output(
1377 "stats_summary", temp_stats_double_t.shape(),
1378 &output_stats_summary_t));
1379 output_stats_summary_t->tensor<float, 4>() =
1380 temp_stats_double.template cast<float>();
1381 }
1382
1383 private:
1384 int max_splits_;
1385 int num_buckets_;
1386 };
1387
1388 REGISTER_KERNEL_BUILDER(Name("BoostedTreesAggregateStats").Device(DEVICE_CPU),
1389 BoostedTreesAggregateStatsOp);
1390
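// The helpers below implement sparse stats aggregation: statistics are first
// accumulated into an ordered map keyed by (node id, feature dimension,
// bucket id) and later serialized into a sparse (indices, values, shape)
// representation by BoostedTreesSparseAggregateStatsOp.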
1391 // Key based on node id, feature dimension and bucket id.
1392 struct StatsPartitionKey {
1393   StatsPartitionKey(const int32 node_id, const int32 feature_dim,
1394 const int32 bucket_id)
1395 : node_id(node_id), feature_dim(feature_dim), bucket_id(bucket_id) {}
1396
1397   bool operator==(const StatsPartitionKey& other) const {
1398 return (node_id == other.node_id) && (feature_dim == other.feature_dim) &&
1399 (bucket_id == other.bucket_id);
1400 }
1401
1402 // Compare for StatsPartitionKey.
1403 struct Less {
1404     bool operator()(const StatsPartitionKey& a,
1405 const StatsPartitionKey& b) const {
1406 if (a.node_id < b.node_id) {
1407 return true;
1408 }
1409 if ((a.node_id == b.node_id) && (a.feature_dim < b.feature_dim)) {
1410 return true;
1411 }
1412 if ((a.node_id == b.node_id) && (a.feature_dim == b.feature_dim) &&
1413 (a.bucket_id < b.bucket_id)) {
1414 return true;
1415 }
1416 return false;
1417 }
1418 };
1419
1420 // Tree node id.
1421 int32 node_id;
1422 // Dimension within feature column.
1423 int32 feature_dim;
1424   // Bucketized feature value.
1425 int32 bucket_id;
1426 };
1427
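// Ordered map from (node id, feature dim, bucket id) to a stats vector of
// length stats_dims (gradient entries followed by hessian entries). Because
// the map is ordered, iterating over it yields entries in sorted index order
// when the sparse output is written.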
1428 typedef std::map<StatsPartitionKey, std::vector<float>, StatsPartitionKey::Less>
1429 StatsPartitionMap;
1430 typedef StatsPartitionMap::iterator StatsPartitionIterator;
1431
1432 // Key based on instance and feature dimension.
1433 struct InstanceFeatureDimKey {
1434   InstanceFeatureDimKey() : instance(-1), feature_dim(-1) {}
1435
1436   InstanceFeatureDimKey(const int32 instance, const int32 feature_dim)
1437 : instance(instance), feature_dim(feature_dim) {}
1438
1439   bool operator==(const InstanceFeatureDimKey& other) const {
1440 return (instance == other.instance) && (feature_dim == other.feature_dim);
1441 }
1442
1443 // Compare for InstanceFeatureDimKey.
1444 struct Less {
1445     bool operator()(const InstanceFeatureDimKey& a,
1446 const InstanceFeatureDimKey& b) const {
1447 if (a.instance < b.instance) {
1448 return true;
1449 }
1450 if ((a.instance == b.instance) && (a.feature_dim < b.feature_dim)) {
1451 return true;
1452 }
1453 return false;
1454 }
1455 };
1456
1457 // Instance id within a batch.
1458 int32 instance;
1459 // Dimension within feature column.
1460 int32 feature_dim;
1461 };
1462
1463 // Add statistics to StatsPartitionMap for (instance, feature dim, bucket id).
1464 static void AddInstanceStatsToMap(const int32 instance, const int32 feature_dim,
1465 const int32 bucket_id,
1466 const int32 logits_dims,
1467 const int32 stats_dims,
1468 StatsPartitionMap* stats_map,
1469 const TTypes<float>::ConstMatrix& gradients,
1470 const TTypes<float>::ConstMatrix& hessians,
1471 const TTypes<int32>::ConstVec& node_ids) {
1472 const int32 node_id = node_ids(instance);
1473 const auto key = StatsPartitionKey(node_id, feature_dim, bucket_id);
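  // Insert a zero-initialized stats vector if the key is not yet present;
  // otherwise the existing entry is returned and accumulated into below.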
1474 std::pair<StatsPartitionIterator, bool> const& insert_result =
1475 stats_map->insert(StatsPartitionIterator::value_type(
1476 key, std::vector<float>(stats_dims, 0.0f)));
1477 auto& stats = insert_result.first->second;
1478 for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
1479 stats[stat_dim] += gradients(instance, stat_dim);
1480 }
1481 for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
1482 stats[stat_dim] += hessians(instance, stat_dim - logits_dims);
1483 }
1484 }
1485
1486 // Add statistics to StatsPartitionMap, into the given bucket_id, for every
1487 // (instance, feature dim) pair strictly between (start_instance,
1488 // start_feature_dim) and (end_instance, end_feature_dim) in row-major order.
1489 static void AddRangeStats(const int start_instance, const int end_instance,
1490 const int start_feature_dim,
1491 const int end_feature_dim,
1492 StatsPartitionMap* stats_map,
1493 const TTypes<float>::ConstMatrix& gradients,
1494 const TTypes<float>::ConstMatrix& hessians,
1495 const TTypes<int32>::ConstVec& node_ids,
1496 const int32 feature_dims, const int32 bucket_id,
1497 const int32 logits_dims, const int32 stats_dims) {
1498 DCHECK_LE(start_instance, end_instance);
1499 if (start_instance == end_instance) {
1500 DCHECK_LT(start_feature_dim, end_feature_dim);
1501 }
1502 for (int32 instance = start_instance; instance <= end_instance; ++instance) {
1503 const int32 start_f_dim =
1504 (instance == start_instance) ? start_feature_dim + 1 : 0;
1505 const int32 end_f_dim =
1506 (instance == end_instance) ? end_feature_dim : feature_dims;
1507 for (int32 f_dim = start_f_dim; f_dim < end_f_dim; ++f_dim) {
1508 AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
1509 stats_map, gradients, hessians, node_ids);
1510 }
1511 }
1512 }
1513
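// Sparse stats aggregation. The bucketized feature is given in sparse form
// (feature_indices, feature_values, feature_shape); any
// (instance, feature dimension) pair with no entry in the sparse input is
// treated as missing and accumulated into the default bucket at index
// num_buckets. The result is emitted as a sparse tensor whose indices are
// (node_id, feature_dim, bucket_id, stat_dim).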
1514 class BoostedTreesSparseAggregateStatsOp : public OpKernel {
1515 public:
1516   explicit BoostedTreesSparseAggregateStatsOp(
1517 OpKernelConstruction* const context)
1518 : OpKernel(context) {
1519 OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
1520 OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
1521 }
1522
1523   void Compute(OpKernelContext* const context) override {
1524 // node_ids.
1525 const Tensor* node_ids_t;
1526 OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
1527 const auto node_ids = node_ids_t->vec<int32>();
1528
1529 // gradients.
1530 const Tensor* gradients_t;
1531 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
1532 const auto gradients = gradients_t->matrix<float>();
1533
1534 // hessians.
1535 const Tensor* hessians_t;
1536 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
1537 const auto hessians = hessians_t->matrix<float>();
1538
1539 // feature indices.
1540 const Tensor* feature_indices_t;
1541 OP_REQUIRES_OK(context,
1542 context->input("feature_indices", &feature_indices_t));
1543 const auto feature_indices = feature_indices_t->matrix<int32>();
1544
1545 // feature values.
1546 const Tensor* feature_values_t;
1547 OP_REQUIRES_OK(context,
1548 context->input("feature_values", &feature_values_t));
1549 const auto feature_values = feature_values_t->vec<int32>();
1550
1551 // feature shape.
1552 const Tensor* feature_shape_t;
1553 OP_REQUIRES_OK(context, context->input("feature_shape", &feature_shape_t));
1554 OP_REQUIRES(context, TensorShapeUtils::IsVector(feature_shape_t->shape()),
1555 errors::InvalidArgument(
1556                     "Input feature_shape should be a vector but received shape ",
1557 feature_shape_t->shape().DebugString()));
1558 const auto feature_shape = feature_shape_t->vec<int32>();
1559
1560 const int64 batch_size = gradients_t->dim_size(0);
1561 const int64 logits_dims = gradients_t->dim_size(1);
1562 const int64 hessians_dims = hessians_t->dim_size(1);
1563 const int64 stats_dims = logits_dims + hessians_dims;
1564 const int64 num_sparse_entries = feature_indices_t->dim_size(0);
1565 const int32 feature_dims = feature_shape(1);
1566 DCHECK_LE(num_sparse_entries, batch_size * feature_dims);
1567
1568 // Aggregate statistics info to map.
1569 StatsPartitionMap stats_map;
1570
1571 int prev_instance = 0;
1572 int prev_f_dim = -1;
1573
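    // The sparse entries are expected in row-major (instance, feature_dim)
    // order (the DCHECKs below assume non-decreasing instance ids). For each
    // present entry, every (instance, feature_dim) pair skipped since the
    // previous entry is counted as missing and accumulated into the default
    // bucket via AddRangeStats; the present entry itself is then added to its
    // own bucket.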
1574 for (int i = 0; i < num_sparse_entries; ++i) {
1575 // the instance number within a batch
1576 const int32 instance = feature_indices(i, 0);
1577 DCHECK_LE(instance, batch_size);
1578 DCHECK_GE(instance, prev_instance);
1579 // the node id within a tree.
1580 const int32 node_id = node_ids(instance);
1581 DCHECK_LE(node_id, max_splits_);
1582 // the feature dimension.
1583 const int32 f_dim = feature_indices(i, 1);
1584 DCHECK_LE(f_dim, feature_dims);
1585 // the bucket id of the value.
1586 const int32 bucket_id = feature_values(i);
1587 DCHECK_LE(bucket_id, num_buckets_);
1588
1589       // Add statistics for the missing entries into the default bucket.
1590       // The last bucket (index num_buckets_) is the default bucket.
1591 const int missing_entry_bucket = num_buckets_;
1592 AddRangeStats(prev_instance, instance, prev_f_dim, f_dim, &stats_map,
1593 gradients, hessians, node_ids, feature_dims,
1594 missing_entry_bucket, logits_dims, stats_dims);
1595 prev_instance = instance;
1596 prev_f_dim = f_dim;
1597       // Add statistics for the non-missing entry into
1598       // (instance, f_dim, bucket_id).
1599 AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
1600 &stats_map, gradients, hessians, node_ids);
1601 }
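    // Account for the remaining missing entries after the last sparse entry,
    // up to the end of the batch.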
1602 AddRangeStats(prev_instance, batch_size - 1, prev_f_dim, feature_dims,
1603 &stats_map, gradients, hessians, node_ids, feature_dims,
1604 num_buckets_, logits_dims, stats_dims);
1605
1606 // Serialize statistics info map to tensor output.
1607 const int64 num_slots = stats_map.size() * stats_dims;
1608 Tensor* summary_indices_t = nullptr;
1609 OP_REQUIRES_OK(context,
1610 context->allocate_output("stats_summary_indices",
1611 TensorShape({num_slots, 4}),
1612 &summary_indices_t));
1613 auto summary_indices = summary_indices_t->matrix<int32>();
1614 Tensor* summary_values_t = nullptr;
1615 OP_REQUIRES_OK(context, context->allocate_output("stats_summary_values",
1616 TensorShape({num_slots}),
1617 &summary_values_t));
1618 auto summary_values = summary_values_t->vec<float>();
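    // Each map entry expands into stats_dims consecutive rows of the sparse
    // output: one (node_id, feature_dim, bucket_id, stat_dim) index per
    // statistic value.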
1619 int entry_index = 0;
1620 for (auto& iter : stats_map) {
1621 for (int stat_dim = 0; stat_dim < stats_dims; ++stat_dim) {
1622 summary_indices(entry_index, 0) = iter.first.node_id;
1623 summary_indices(entry_index, 1) = iter.first.feature_dim;
1624 summary_indices(entry_index, 2) = iter.first.bucket_id;
1625 summary_indices(entry_index, 3) = stat_dim;
1626 summary_values(entry_index) = iter.second[stat_dim];
1627 ++entry_index;
1628 }
1629 }
1630
1631 Tensor* summary_shape_t = nullptr;
1632 OP_REQUIRES_OK(
1633 context, context->allocate_output("stats_summary_shape",
1634 TensorShape({4}), &summary_shape_t));
1635 auto summary_shape = summary_shape_t->vec<int32>();
1636 summary_shape(0) = max_splits_;
1637 summary_shape(1) = feature_dims;
1638 summary_shape(2) = num_buckets_ + 1;
1639 summary_shape(3) = stats_dims;
1640 }
1641
1642 private:
1643 int max_splits_;
1644 int num_buckets_;
1645 };
1646
1647 REGISTER_KERNEL_BUILDER(
1648 Name("BoostedTreesSparseAggregateStats").Device(DEVICE_CPU),
1649 BoostedTreesSparseAggregateStatsOp);
1650
1651 } // namespace tensorflow
1652