/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>
#include <string>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using Matrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;

using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
using VectorMap = Eigen::Map<Eigen::VectorXf>;

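// Names of the two supported split types; the V2 op below validates its
// per-feature "split_types" input against these values.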
constexpr char kInequalitySplit[] = "inequality";
constexpr char kEqualitySplit[] = "equality";

// V1 Op. Deprecated. Use BoostedTreesCalculateBestFeatureSplitOpV2 instead.
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
 public:
  explicit BoostedTreesCalculateBestGainsPerFeatureOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    OP_REQUIRES(
        context, node_id_range_t->dims() == 1,
        errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
                                "given node_id_range has dims of ",
                                node_id_range_t->dims()));
    OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
                errors::InvalidArgument(
                    "node_id_range must be a rank 1 tensor with shape=[2], but "
                    "given node_id_range has shape ",
                    node_id_range_t->dim_size(0), " on its first dim"));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32_t node_id_first = node_id_range(0);  // inclusive
    const int32_t node_id_last = node_id_range(1);   // exclusive
    // stats_summary_list
    OpInputList stats_summary_list;
    OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
                                                &stats_summary_list));
    const int64_t num_buckets = stats_summary_list[0].dim_size(1);
    // Check for single logit: 1 gradient + 1 hessian value.
    DCHECK_EQ(stats_summary_list[0].dim_size(2), 2);
    std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
    stats_summary.reserve(stats_summary_list.size());
    for (const auto& tensor : stats_summary_list) {
      stats_summary.emplace_back(tensor.tensor<float, 3>());
    }
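    // Each stats_summary_list[f] has shape [max_splits, num_buckets, 2]:
    // entry (node_id, bucket, 0) is the accumulated gradient and
    // (node_id, bucket, 1) the accumulated hessian for that node and bucket.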
    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();
    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();
    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    // Allocate output lists of tensors:
    OpOutputList output_node_ids_list;
    OP_REQUIRES_OK(
        context, context->output_list("node_ids_list", &output_node_ids_list));
    OpOutputList output_gains_list;
    OP_REQUIRES_OK(context,
                   context->output_list("gains_list", &output_gains_list));
    OpOutputList output_thresholds_list;
    OP_REQUIRES_OK(context, context->output_list("thresholds_list",
                                                 &output_thresholds_list));
    OpOutputList output_left_node_contribs_list;
    OP_REQUIRES_OK(context,
                   context->output_list("left_node_contribs_list",
                                        &output_left_node_contribs_list));
    OpOutputList output_right_node_contribs_list;
    OP_REQUIRES_OK(context,
                   context->output_list("right_node_contribs_list",
                                        &output_right_node_contribs_list));

    // Use identity later to convert float to Eigen::Matrix type for input to
    // CalculateWeightsAndGains. This op only supports single dimension logits.
    Eigen::MatrixXf identity;
    identity.setIdentity(1, 1);
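    // Rough sketch of what CalculateWeightsAndGains (tree_helper.h) does for
    // the single-logit case, ignoring l1: it computes the regularized Newton
    // step w = -g / (h + l2) and a gain proportional to g^2 / (h + l2). The
    // exact handling of l1 and of multi-dimensional hessians lives in the
    // helper; this comment is only an approximation.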
    // Get the best split info per node for each feature.
    for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
      std::vector<float> cum_grad;
      std::vector<float> cum_hess;
      cum_grad.reserve(num_buckets);
      cum_hess.reserve(num_buckets);

      std::vector<int32> output_node_ids;
      std::vector<float> output_gains;
      std::vector<int32> output_thresholds;
      std::vector<float> output_left_node_contribs;
      std::vector<float> output_right_node_contribs;
      for (int node_id = node_id_first; node_id < node_id_last; ++node_id) {
        // Calculate gains.
        cum_grad.clear();
        cum_hess.clear();
        float total_grad = 0.0;
        float total_hess = 0.0;
        for (int bucket = 0; bucket < num_buckets; ++bucket) {
          // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
          total_grad += stats_summary[feature_idx](node_id, bucket, 0);
          total_hess += stats_summary[feature_idx](node_id, bucket, 1);
          cum_grad.push_back(total_grad);
          cum_hess.push_back(total_hess);
        }
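        // At this point cum_grad[b] / cum_hess[b] hold the prefix sums over
        // buckets 0..b, so splitting at bucket b sends those buckets to the
        // left child and (total - cum) to the right child. Illustrative
        // example (made-up numbers): per-bucket grads {1, 2, 3} give
        // cum_grad = {1, 3, 6} and total_grad = 6.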
        // Check if the node has enough average hessian.
        if (total_hess < min_node_weight) {
          // Do not split the node because there is not enough average hessian.
          continue;
        }
        float best_gain = std::numeric_limits<float>::lowest();
        float best_bucket = 0;
        float best_contrib_for_left = 0.0;
        float best_contrib_for_right = 0.0;
        // Parent gain.
        float parent_gain;
        Eigen::VectorXf unused(1);
        CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
                                 l1, l2, &unused, &parent_gain);

        for (int bucket = 0; bucket < num_buckets; ++bucket) {
          const float cum_grad_bucket = cum_grad[bucket];
          const float cum_hess_bucket = cum_hess[bucket];
          // Left child.
          Eigen::VectorXf contrib_for_left(1);
          float gain_for_left;
          CalculateWeightsAndGains(cum_grad_bucket * identity,
                                   cum_hess_bucket * identity, l1, l2,
                                   &contrib_for_left, &gain_for_left);
          // Right child.
          // use contrib_for_right.
          Eigen::VectorXf contrib_for_right(1);
          float gain_for_right;
          CalculateWeightsAndGains((total_grad - cum_grad_bucket) * identity,
                                   (total_hess - cum_hess_bucket) * identity,
                                   l1, l2, &contrib_for_right, &gain_for_right);

          if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
            best_gain = gain_for_left + gain_for_right;
            best_bucket = bucket;
            best_contrib_for_left = contrib_for_left[0];
            best_contrib_for_right = contrib_for_right[0];
          }
        }  // for bucket
        output_node_ids.push_back(node_id);
        // Remove the parent gain for the parent node.
        output_gains.push_back(best_gain - parent_gain);
        output_thresholds.push_back(best_bucket);
        output_left_node_contribs.push_back(best_contrib_for_left);
        output_right_node_contribs.push_back(best_contrib_for_right);
      }  // for node_id
      const int num_nodes = output_node_ids.size();
      // output_node_ids
      Tensor* output_node_ids_t;
      OP_REQUIRES_OK(context,
                     output_node_ids_list.allocate(feature_idx, {num_nodes},
                                                   &output_node_ids_t));
      auto output_node_ids_vec = output_node_ids_t->vec<int32>();
      // output_gains
      Tensor* output_gains_t;
      OP_REQUIRES_OK(context, output_gains_list.allocate(
                                  feature_idx, {num_nodes}, &output_gains_t));
      auto output_gains_vec = output_gains_t->vec<float>();
      // output_thresholds
      Tensor* output_thresholds_t;
      OP_REQUIRES_OK(context,
                     output_thresholds_list.allocate(feature_idx, {num_nodes},
                                                     &output_thresholds_t));
      auto output_thresholds_vec = output_thresholds_t->vec<int32>();
      // output_left_node_contribs
      Tensor* output_left_node_contribs_t;
      OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
                                  feature_idx, {num_nodes, 1},
                                  &output_left_node_contribs_t));
      auto output_left_node_contribs_matrix =
          output_left_node_contribs_t->matrix<float>();
      // output_right_node_contribs
      Tensor* output_right_node_contribs_t;
      OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
                                  feature_idx, {num_nodes, 1},
                                  &output_right_node_contribs_t));
      auto output_right_node_contribs_matrix =
          output_right_node_contribs_t->matrix<float>();
      // Sets output tensors from vectors.
      for (int i = 0; i < num_nodes; ++i) {
        output_node_ids_vec(i) = output_node_ids[i];
        // Adjust the gains to penalize by tree complexity.
        output_gains_vec(i) = output_gains[i] - tree_complexity;
        output_thresholds_vec(i) = output_thresholds[i];
        output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
        // This op only supports 1-dimensional logits.
        output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
      }
    }  // for f
  }

 private:
  int max_splits_;
  int num_features_;
};

// V1 op that only supports single dimensional logit.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
    BoostedTreesCalculateBestGainsPerFeatureOp);

// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
 public:
  explicit BoostedTreesCalculateBestFeatureSplitOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    OP_REQUIRES(
        context, node_id_range_t->NumElements() == 2,
        errors::InvalidArgument("node_id_range argument must have shape [2]"));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32_t node_id_first = node_id_range(0);  // inclusive
    const int32_t node_id_last = node_id_range(1);   // exclusive

    const Tensor* stats_summary_t;
    OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
    OP_REQUIRES(
        context, stats_summary_t->shape().dims() == 4,
        errors::InvalidArgument("stats_summary argument must have rank 4"));
    TTypes<float, 4>::ConstTensor stats_summary =
        stats_summary_t->tensor<float, 4>();
    const int32_t feature_dims = stats_summary_t->dim_size(1);
    // The last bucket is for default/missing value.
    const int32_t num_buckets = stats_summary_t->dim_size(2) - 1;
    const int32_t logits_dim = logits_dim_;
    const int32_t hessian_dim = stats_summary_t->dim_size(3) - logits_dim;
    DCHECK_GT(hessian_dim, 0);
    DCHECK_LE(hessian_dim, logits_dim * logits_dim);
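    // The last dimension of stats_summary packs [gradient, hessian] per
    // (node, feature_dim, bucket): the first logits_dim entries are the
    // gradient and the remaining hessian_dim entries the hessian, which per
    // the DCHECKs above is typically either a diagonal approximation
    // (hessian_dim == logits_dim) or a full flattened hessian
    // (hessian_dim == logits_dim * logits_dim).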

    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    OP_REQUIRES(context, l1_t->NumElements() == 1,
                errors::InvalidArgument("l1 argument must be a scalar"));
    const auto l1 = l1_t->scalar<float>()();
    DCHECK_GE(l1, 0);
    if (logits_dim_ > 1) {
      // Multi-class L1 regularization not supported yet.
      DCHECK_EQ(l1, 0);
    }

    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    OP_REQUIRES(context, l2_t->NumElements() == 1,
                errors::InvalidArgument("l2 argument must be a scalar"));
    const auto l2 = l2_t->scalar<float>()();
    DCHECK_GE(l2, 0);

    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    OP_REQUIRES(
        context, tree_complexity_t->NumElements() == 1,
        errors::InvalidArgument("tree_complexity argument must be a scalar"));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();

    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    OP_REQUIRES(
        context, min_node_weight_t->NumElements() == 1,
        errors::InvalidArgument("min_node_weight argument must be a scalar"));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    std::vector<int32> output_node_ids;
    std::vector<float> output_gains;
    std::vector<int32> output_feature_dimensions;
    std::vector<int32> output_thresholds;
    std::vector<Eigen::VectorXf> output_left_node_contribs;
    std::vector<Eigen::VectorXf> output_right_node_contribs;
    std::vector<std::string> output_split_types;

    // TODO(tanzheny) parallelize the computation.
    // Iterate each node and find the best gain per node.
    for (int32_t node_id = node_id_first; node_id < node_id_last; ++node_id) {
      float best_gain = std::numeric_limits<float>::lowest();
      int32_t best_bucket = 0;
      int32_t best_f_dim = 0;
      string best_split_type;
      Eigen::VectorXf best_contrib_for_left(logits_dim);
      Eigen::VectorXf best_contrib_for_right(logits_dim);
      float parent_gain;

      // Including default bucket.
      ConstMatrixMap stats_mat(&stats_summary(node_id, 0, 0, 0),
                               num_buckets + 1, logits_dim + hessian_dim);
      const Eigen::VectorXf total_grad =
          stats_mat.leftCols(logits_dim).colwise().sum();
      const Eigen::VectorXf total_hess =
          stats_mat.rightCols(hessian_dim).colwise().sum();
      if (total_hess.norm() < min_node_weight) {
        continue;
      }
      Eigen::VectorXf parent_weight(logits_dim);
      CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &parent_weight,
                               &parent_gain);
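      // The gain reported below is (left child gain + right child gain -
      // parent_gain); tree_complexity is subtracted later, when the output
      // tensors are filled.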

      if (split_type_ == "inequality") {
        CalculateBestInequalitySplit(
            stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
            num_buckets, min_node_weight, l1, l2, &best_gain, &best_bucket,
            &best_f_dim, &best_split_type, &best_contrib_for_left,
            &best_contrib_for_right);
      } else {
        CalculateBestEqualitySplit(
            stats_summary, total_grad, total_hess, node_id, feature_dims,
            logits_dim, hessian_dim, num_buckets, l1, l2, &best_gain,
            &best_bucket, &best_f_dim, &best_split_type, &best_contrib_for_left,
            &best_contrib_for_right);
      }

      if (best_gain == std::numeric_limits<float>::lowest()) {
        // Do not add the node if no split is found.
        continue;
      }
      output_node_ids.push_back(node_id);
      // Remove the parent gain for the parent node.
      output_gains.push_back(best_gain - parent_gain);
      output_feature_dimensions.push_back(best_f_dim);
      // Default direction is fixed for dense splits.
      // TODO(tanzheny) account for default values.
      output_split_types.push_back(best_split_type);
      output_thresholds.push_back(best_bucket);
      output_left_node_contribs.push_back(best_contrib_for_left);
      output_right_node_contribs.push_back(best_contrib_for_right);
    }  // for node id
    const int num_nodes = output_node_ids.size();
    // output_node_ids
    Tensor* output_node_ids_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
                                                     &output_node_ids_t));
    auto output_node_ids_vec = output_node_ids_t->vec<int32>();

    // output_gains
    Tensor* output_gains_t;
    OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
                                                     &output_gains_t));
    auto output_gains_vec = output_gains_t->vec<float>();

    // output_feature_dimensions
    Tensor* output_feature_dimension_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("feature_dimensions", {num_nodes},
                                            &output_feature_dimension_t));
    auto output_feature_dimensions_vec =
        output_feature_dimension_t->vec<int32>();

    // output_thresholds
    Tensor* output_thresholds_t;
    OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
                                                     &output_thresholds_t));
    auto output_thresholds_vec = output_thresholds_t->vec<int32>();

    // output_left_node_contribs
    Tensor* output_left_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "left_node_contribs", {num_nodes, logits_dim},
                                &output_left_node_contribs_t));
    auto output_left_node_contribs_matrix =
        output_left_node_contribs_t->matrix<float>();

    // output_right_node_contribs
    Tensor* output_right_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "right_node_contribs", {num_nodes, logits_dim},
                                &output_right_node_contribs_t));
    auto output_right_node_contribs_matrix =
        output_right_node_contribs_t->matrix<float>();

    // split type
    Tensor* output_split_types_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("split_with_default_directions",
                                          {num_nodes}, &output_split_types_t));
    auto output_split_types_vec = output_split_types_t->vec<tstring>();

    // Sets output tensors from vectors.
    for (int i = 0; i < num_nodes; ++i) {
      output_node_ids_vec(i) = output_node_ids[i];
      // Adjust the gains to penalize by tree complexity.
      output_gains_vec(i) = output_gains[i] - tree_complexity;
      output_feature_dimensions_vec(i) = output_feature_dimensions[i];
      output_thresholds_vec(i) = output_thresholds[i];
      for (int j = 0; j < logits_dim; ++j) {
        output_left_node_contribs_matrix(i, j) =
            output_left_node_contribs[i][j];
        output_right_node_contribs_matrix(i, j) =
            output_right_node_contribs[i][j];
      }
      output_split_types_vec(i) = output_split_types[i];
    }
  }

 private:
  // TODO(crawles): Simplify inequality path just like equality b/138329196.
  // Currently this cannot be simplified, due to numerical instability in the
  // math, i.e. gain = -g.transpose() *
  // hessian_and_reg.colPivHouseholderQr().solve(g), which caused the gain to
  // become Inf when g approaches 0 but is not exactly 0 while there is no
  // regularization.
  // Calculate the best inequality split per node.
  void CalculateBestInequalitySplit(
      TTypes<float, 4>::ConstTensor stats_summary, const int32_t node_id,
      const int32_t feature_dims, const int32_t logits_dim,
      const int32_t hessian_dim, const int32_t num_buckets,
      const float min_node_weight, const float l1, const float l2,
      float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    std::vector<Eigen::VectorXf> cum_grad;
    std::vector<Eigen::VectorXf> cum_hess;
    // Get all cumulative gradients including default bucket.
    cum_grad.reserve(num_buckets);
    cum_hess.reserve(num_buckets);
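    // cum_grad[b] / cum_hess[b] will hold the running sums over buckets 0..b
    // for the current feature dimension; the extra bucket at index
    // num_buckets holds the stats of missing values and is handled separately
    // via missing_bucket_grad / missing_bucket_hess below.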

    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      ConstVectorMap default_stats_vec(
          &stats_summary(node_id, f_dim, num_buckets, 0),
          logits_dim + hessian_dim);
      Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
      Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
      cum_grad.clear();
      cum_hess.clear();
      Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
      Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
      // Sum all the gradients including default bucket.
      for (int bucket = 0; bucket <= num_buckets; ++bucket) {
        for (int i = 0; i < logits_dim; ++i) {
          total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
        }
        for (int i = 0; i < hessian_dim; ++i) {
          // Full hessian.
          total_hess[i] +=
              stats_summary(node_id, f_dim, bucket, logits_dim + i);
        }
        if (bucket < num_buckets) {
          cum_grad.push_back(total_grad);
          cum_hess.push_back(total_hess);
        }
      }
      const string kInequalityDefaultLeft =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_LEFT);
      const string kInequalityDefaultRight =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_RIGHT);

      // Iterate from left to right, excluding default bucket.
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        // Default value goes to left node.
        const Eigen::VectorXf total_left_grad =
            cum_grad[bucket] + missing_bucket_grad;
        const Eigen::VectorXf total_left_hess =
            cum_hess[bucket] + missing_bucket_hess;
        MaybeUpdateBestSplit(
            total_left_grad, total_grad - total_left_grad, total_left_hess,
            total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
        // Default value goes to right node.
        MaybeUpdateBestSplit(
            cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
            total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
      }  // for bucket
    }
  }

  // Calculate the best equality split per node.
  void CalculateBestEqualitySplit(
      TTypes<float, 4>::ConstTensor stats_summary,
      const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
      const int32_t node_id, const int32_t feature_dims,
      const int32_t logits_dim, const int32_t hessian_dim,
      const int32_t num_buckets, const float l1, const float l2,
      float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    const string kEqualityDefaultRight =
        boosted_trees::SplitTypeWithDefault_Name(
            boosted_trees::EQUALITY_DEFAULT_RIGHT);
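    // An equality (categorical) candidate is "feature value == bucket": the
    // bucket's own stats form the left child and everything else, including
    // the missing-value bucket, stays on the right, hence the single
    // EQUALITY_DEFAULT_RIGHT split type.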
    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
                                 logits_dim + hessian_dim);
        Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
        Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
        MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
                             total_hess - curr_hess, logits_dim, bucket, f_dim,
                             l1, l2, kEqualityDefaultRight, best_gain,
                             best_bucket, best_f_dim, best_split_type,
                             best_contrib_for_left, best_contrib_for_right);
      }
    }
  }

  void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
                            const Eigen::VectorXf& grad_for_right,
                            const Eigen::VectorXf& hess_for_left,
                            const Eigen::VectorXf& hess_for_right,
                            const int32_t logits_dim, const int32_t bucket,
                            const int32_t f_dim, const float l1, const float l2,
                            const string split_type, float* best_gain,
                            int32* best_bucket, int32* best_f_dim,
                            string* best_split_type,
                            Eigen::VectorXf* best_contrib_for_left,
                            Eigen::VectorXf* best_contrib_for_right) {
    // Left child.
    Eigen::VectorXf contrib_for_left(logits_dim);
    float gain_for_left;
    CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
                             &contrib_for_left, &gain_for_left);
    Eigen::VectorXf contrib_for_right(logits_dim);
    float gain_for_right;
    CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
                             &contrib_for_right, &gain_for_right);
    if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
      *best_gain = gain_for_left + gain_for_right;
      *best_bucket = bucket;
      *best_f_dim = f_dim;
      *best_contrib_for_left = contrib_for_left;
      *best_contrib_for_right = contrib_for_right;
      *best_split_type = split_type;
    }
  }

  int logits_dim_;
  string split_type_;
};

// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
    BoostedTreesCalculateBestFeatureSplitOp);

// V2 Op.
class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
 public:
  explicit BoostedTreesCalculateBestFeatureSplitV2(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    const auto node_id_range = node_id_range_t->vec<int32>();
    OP_REQUIRES(
        context, node_id_range_t->dims() == 1,
        errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
                                "given node_id_range has dims of ",
                                node_id_range_t->dims()));
    OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
                errors::InvalidArgument(
                    "node_id_range must be a rank 1 tensor with shape=[2], but "
                    "given node_id_range has shape ",
                    node_id_range_t->dim_size(0), " on its first dim"));
    const int32_t node_id_first = node_id_range(0);  // Inclusive.
    const int32_t node_id_last = node_id_range(1);   // Exclusive.

    // Get stats_summaries_list.
    OpInputList stats_summaries_list;
    OP_REQUIRES_OK(context, context->input_list("stats_summaries_list",
                                                &stats_summaries_list));

    // Infer dimensions of a stats_summary.
    DCHECK_GT(stats_summaries_list.size(), 0);
    const int32_t feature_dims = stats_summaries_list[0].dim_size(1);
    // The last bucket is for default/missing value.
    const int32_t num_buckets = stats_summaries_list[0].dim_size(2) - 1;
    const int32_t logits_dim = logits_dim_;
    const int32_t hessian_dim =
        stats_summaries_list[0].dim_size(3) - logits_dim;
    DCHECK_GT(hessian_dim, 0);
    DCHECK_LE(hessian_dim, logits_dim * logits_dim);

    // Vector of stats_summaries; each element is stats for feature of shape
    // [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim].
    std::vector<TTypes<float, 4>::ConstTensor> stats_summaries;
    DCHECK_EQ(stats_summaries_list.size(), num_features_);
    stats_summaries.reserve(num_features_);
    for (const auto& tensor : stats_summaries_list) {
      stats_summaries.emplace_back(tensor.tensor<float, 4>());
    }

    // Split types.
    const Tensor* split_types_t;
    OP_REQUIRES_OK(context, context->input("split_types", &split_types_t));
    const auto split_types = split_types_t->vec<tstring>();
    DCHECK_EQ(split_types.size(), num_features_);
    // Validate.
    for (int i = 0; i < num_features_; ++i) {
      if (!(split_types(i) == kInequalitySplit ||
            split_types(i) == kEqualitySplit)) {
        OP_REQUIRES_OK(
            context,
            errors::Aborted(
                "Operation received an exception: Incorrect split type"));
      }
    }
    // Feature ids.
    const Tensor* candidate_feature_ids_t;
    OP_REQUIRES_OK(context, context->input("candidate_feature_ids",
                                           &candidate_feature_ids_t));
    const auto candidate_feature_ids = candidate_feature_ids_t->vec<int32>();
    DCHECK_EQ(candidate_feature_ids.size(), num_features_);

    // L1, L2, tree_complexity, min_node_weight.
    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    DCHECK_GE(l1, 0);
    if (logits_dim_ > 1) {
      // Multi-class L1 regularization not supported yet.
      DCHECK_EQ(l1, 0);
    }
    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();
    DCHECK_GE(l2, 0);
    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();
    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    std::vector<int32> output_node_ids;
    std::vector<float> output_gains;
    std::vector<int32> output_feature_ids;
    std::vector<int32> output_feature_dimensions;
    std::vector<int32> output_thresholds;
    std::vector<Eigen::VectorXf> output_left_node_contribs;
    std::vector<Eigen::VectorXf> output_right_node_contribs;
    std::vector<string> output_split_types;

    // TODO(tanzheny) parallelize the computation.
    // Iterate each node and find the best gain per node.
    float parent_gain;
    for (int32_t node_id = node_id_first; node_id < node_id_last; ++node_id) {
      float best_gain = std::numeric_limits<float>::lowest();
      int32_t best_bucket;
      int32_t best_f_id;
      int32_t best_f_dim;
      string best_split_type;
      Eigen::VectorXf best_contrib_for_left(logits_dim);
      Eigen::VectorXf best_contrib_for_right(logits_dim);

      // Sum of gradient and hessian. Compute parent gain using first feature.
      ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0),
                               num_buckets + 1,  // Including default bucket.
                               logits_dim + hessian_dim);
      const Eigen::VectorXf total_grad =
          stats_mat.leftCols(logits_dim).colwise().sum();
      const Eigen::VectorXf total_hess =
          stats_mat.rightCols(hessian_dim).colwise().sum();
      if (total_hess.norm() < min_node_weight) {
        continue;
      }
      Eigen::VectorXf unused(logits_dim);
      CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
                               &parent_gain);
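      // The totals (and hence parent_gain) are computed from the first
      // feature only: every example falls into exactly one bucket per
      // feature, so the per-node sums over all buckets, including the default
      // bucket, are the same for every candidate feature.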
      for (int f_idx = 0; f_idx < num_features_; ++f_idx) {
        const string split_type = split_types(f_idx);
        TTypes<float, 4>::ConstTensor stats_summary = stats_summaries[f_idx];
        float f_best_gain = std::numeric_limits<float>::lowest();
        int32_t f_best_bucket;
        int32_t f_best_f_dim;
        string f_best_split_type;
        Eigen::VectorXf f_best_contrib_for_left(logits_dim);
        Eigen::VectorXf f_best_contrib_for_right(logits_dim);

        if (split_type == kInequalitySplit) {
          CalculateBestInequalitySplit(
              stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
              num_buckets, min_node_weight, l1, l2, &f_best_gain,
              &f_best_bucket, &f_best_f_dim, &f_best_split_type,
              &f_best_contrib_for_left, &f_best_contrib_for_right);
        } else {
          CalculateBestEqualitySplit(
              stats_summary, total_grad, total_hess, node_id, feature_dims,
              logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain,
              &f_best_bucket, &f_best_f_dim, &f_best_split_type,
              &f_best_contrib_for_left, &f_best_contrib_for_right);
        }
        if (f_best_gain > best_gain) {
          best_gain = f_best_gain;
          best_f_id = candidate_feature_ids(f_idx);
          best_f_dim = f_best_f_dim;
          best_split_type = f_best_split_type;
          best_bucket = f_best_bucket;
          best_contrib_for_left = f_best_contrib_for_left;
          best_contrib_for_right = f_best_contrib_for_right;
        }
      }  // For feature id.
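      // Note the strict ">" above: if two features tie on gain, the one that
      // appears earlier in candidate_feature_ids wins.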
      if (best_gain == std::numeric_limits<float>::lowest()) {
        // Do not add the node if no split is found.
        continue;
      }
      output_node_ids.push_back(node_id);
      // Remove the parent gain for the parent node.
      output_gains.push_back(best_gain - parent_gain);
      output_feature_ids.push_back(best_f_id);
      output_feature_dimensions.push_back(best_f_dim);
      // Default direction is fixed for dense splits.
      // TODO(tanzheny) account for default values.
      output_split_types.push_back(best_split_type);
      output_thresholds.push_back(best_bucket);
      output_left_node_contribs.push_back(best_contrib_for_left);
      output_right_node_contribs.push_back(best_contrib_for_right);
    }  // for node id.
    const int num_nodes = output_node_ids.size();
    // output_node_ids
    Tensor* output_node_ids_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
                                                     &output_node_ids_t));
    auto output_node_ids_vec = output_node_ids_t->vec<int32>();

    // output_gains
    Tensor* output_gains_t;
    OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
                                                     &output_gains_t));
    auto output_gains_vec = output_gains_t->vec<float>();

    // output_feature_ids
    Tensor* output_features_ids_t;
    OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes},
                                                     &output_features_ids_t));
    auto output_features_vec = output_features_ids_t->vec<int32>();

    // output_feature_dimensions
    Tensor* output_feature_dimension_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("feature_dimensions", {num_nodes},
                                            &output_feature_dimension_t));
    auto output_feature_dimensions_vec =
        output_feature_dimension_t->vec<int32>();

    // output_thresholds
    Tensor* output_thresholds_t;
    OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
                                                     &output_thresholds_t));
    auto output_thresholds_vec = output_thresholds_t->vec<int32>();

    // output_left_node_contribs
    Tensor* output_left_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "left_node_contribs", {num_nodes, logits_dim},
                                &output_left_node_contribs_t));
    auto output_left_node_contribs_matrix =
        output_left_node_contribs_t->matrix<float>();

    // output_right_node_contribs
    Tensor* output_right_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "right_node_contribs", {num_nodes, logits_dim},
                                &output_right_node_contribs_t));
    auto output_right_node_contribs_matrix =
        output_right_node_contribs_t->matrix<float>();

    // split type
    Tensor* output_split_types_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("split_with_default_directions",
                                          {num_nodes}, &output_split_types_t));
    auto output_split_types_vec = output_split_types_t->vec<tstring>();

    // Sets output tensors from vectors.
    for (int i = 0; i < num_nodes; ++i) {
      output_node_ids_vec(i) = output_node_ids[i];
      output_features_vec(i) = output_feature_ids[i];
      // Adjust the gains to penalize by tree complexity.
      output_gains_vec(i) = output_gains[i] - tree_complexity;
      output_feature_dimensions_vec(i) = output_feature_dimensions[i];
      output_thresholds_vec(i) = output_thresholds[i];
      for (int j = 0; j < logits_dim; ++j) {
        output_left_node_contribs_matrix(i, j) =
            output_left_node_contribs[i][j];
        output_right_node_contribs_matrix(i, j) =
            output_right_node_contribs[i][j];
      }
      output_split_types_vec(i) = output_split_types[i];
    }
  }

 private:
  // TODO(crawles): Simplify inequality path just like equality b/138329196.
  // Currently this cannot be simplified, due to numerical instability in the
  // math, i.e. gain = -g.transpose() *
  // hessian_and_reg.colPivHouseholderQr().solve(g), which caused the gain to
  // become Inf when g approaches 0 but is not exactly 0 while there is no
  // regularization.
  // Calculate the best inequality split per node.
  void CalculateBestInequalitySplit(
      TTypes<float, 4>::ConstTensor stats_summary, const int32_t node_id,
      const int32_t feature_dims, const int32_t logits_dim,
      const int32_t hessian_dim, const int32_t num_buckets,
      const float min_node_weight, const float l1, const float l2,
      float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    std::vector<Eigen::VectorXf> cum_grad;
    std::vector<Eigen::VectorXf> cum_hess;
    // Get all cumulative gradients including default bucket.
    cum_grad.reserve(num_buckets);
    cum_hess.reserve(num_buckets);

    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      ConstVectorMap default_stats_vec(
          &stats_summary(node_id, f_dim, num_buckets, 0),
          logits_dim + hessian_dim);
      Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
      Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
      cum_grad.clear();
      cum_hess.clear();
      Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
      Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
      // Sum all the gradients including default bucket.
      for (int bucket = 0; bucket <= num_buckets; ++bucket) {
        for (int i = 0; i < logits_dim; ++i) {
          total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
        }
        for (int i = 0; i < hessian_dim; ++i) {
          // Full hessian.
          total_hess[i] +=
              stats_summary(node_id, f_dim, bucket, logits_dim + i);
        }
        if (bucket < num_buckets) {
          cum_grad.push_back(total_grad);
          cum_hess.push_back(total_hess);
        }
      }
      const string kInequalityDefaultLeft =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_LEFT);
      const string kInequalityDefaultRight =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_RIGHT);

      // Iterate from left to right, excluding default bucket.
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        // Default value goes to left node.
        const Eigen::VectorXf total_left_grad =
            cum_grad[bucket] + missing_bucket_grad;
        const Eigen::VectorXf total_left_hess =
            cum_hess[bucket] + missing_bucket_hess;
        MaybeUpdateBestSplit(
            total_left_grad, total_grad - total_left_grad, total_left_hess,
            total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
        // Default value goes to right node.
        MaybeUpdateBestSplit(
            cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
            total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
      }  // for bucket
    }
  }

  // Calculate the best equality split per node.
  void CalculateBestEqualitySplit(
      TTypes<float, 4>::ConstTensor stats_summary,
      const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
      const int32_t node_id, const int32_t feature_dims,
      const int32_t logits_dim, const int32_t hessian_dim,
      const int32_t num_buckets, const float l1, const float l2,
      float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    const string kEqualityDefaultRight =
        boosted_trees::SplitTypeWithDefault_Name(
            boosted_trees::EQUALITY_DEFAULT_RIGHT);
    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
                                 logits_dim + hessian_dim);
        Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
        Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
        MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
                             total_hess - curr_hess, logits_dim, bucket, f_dim,
                             l1, l2, kEqualityDefaultRight, best_gain,
                             best_bucket, best_f_dim, best_split_type,
                             best_contrib_for_left, best_contrib_for_right);
      }
    }
  }

  void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
                            const Eigen::VectorXf& grad_for_right,
                            const Eigen::VectorXf& hess_for_left,
                            const Eigen::VectorXf& hess_for_right,
                            const int32_t logits_dim, const int32_t bucket,
                            const int32_t f_dim, const float l1, const float l2,
                            const string split_type, float* best_gain,
                            int32* best_bucket, int32* best_f_dim,
                            string* best_split_type,
                            Eigen::VectorXf* best_contrib_for_left,
                            Eigen::VectorXf* best_contrib_for_right) {
    // Left child.
    Eigen::VectorXf contrib_for_left(logits_dim);
    float gain_for_left;
    CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
                             &contrib_for_left, &gain_for_left);
    Eigen::VectorXf contrib_for_right(logits_dim);
    float gain_for_right;
    CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
                             &contrib_for_right, &gain_for_right);
    if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
      *best_gain = gain_for_left + gain_for_right;
      *best_bucket = bucket;
      *best_f_dim = f_dim;
      *best_contrib_for_left = contrib_for_left;
      *best_contrib_for_right = contrib_for_right;
      *best_split_type = split_type;
    }
  }
  int num_features_;
  int logits_dim_;
};

// V2 op that supports multi-class.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU),
    BoostedTreesCalculateBestFeatureSplitV2);

// Map from bucket id to vector of statistics.
typedef std::map<int32, std::vector<float>> BucketMap;
typedef BucketMap::iterator BucketMapIterator;
// Map from feature dimension to BucketMap.
typedef std::map<int32, BucketMap> FeatureMap;
typedef FeatureMap::iterator FeatureMapIterator;

class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
 public:
  explicit BoostedTreesSparseCalculateBestFeatureSplitOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    // TODO(crawles): Using logits_dim_ for multi-class split.
    OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
    // TODO(tanzheny): Using this for equality split.
    OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32_t node_id_first = node_id_range(0);  // inclusive
    const int32_t node_id_last = node_id_range(1);   // exclusive

    const Tensor* stats_summary_indices_t;
    OP_REQUIRES_OK(context, context->input("stats_summary_indices",
                                           &stats_summary_indices_t));
    const auto stats_summary_indices = stats_summary_indices_t->matrix<int32>();
    const int32_t num_sparse_entries = stats_summary_indices_t->dim_size(0);

    const Tensor* stats_summary_values_t;
    OP_REQUIRES_OK(context, context->input("stats_summary_values",
                                           &stats_summary_values_t));
    const auto stats_summary_values = stats_summary_values_t->vec<float>();

    const Tensor* stats_summary_shape_t;
    OP_REQUIRES_OK(
        context, context->input("stats_summary_shape", &stats_summary_shape_t));
    const auto stats_summary_shape = stats_summary_shape_t->vec<int32>();
    const int32_t num_buckets = stats_summary_shape(2) - 1;
    const int32_t stats_dims = stats_summary_shape(3);
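    // The sparse stats arrive in COO form: each row of stats_summary_indices
    // is [node_id, feature_dim, bucket_id, stat_dim] and the matching entry
    // of stats_summary_values holds that statistic. The loop below assumes
    // rows are grouped by node_id, since a change of node id flushes the
    // accumulated map through process_node().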

    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();

    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();

    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();

    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    std::vector<int32> output_node_ids;
    std::vector<float> output_gains;
    std::vector<int32> output_feature_dimensions;
    std::vector<int32> output_thresholds;
    std::vector<float> output_left_node_contribs;
    std::vector<float> output_right_node_contribs;
    std::vector<string> output_split_types;

    FeatureMap f_map;

    int32_t previous_node_id = -1;
    for (int idx = 0; idx < num_sparse_entries; ++idx) {
      int32_t node_id = stats_summary_indices(idx, 0);
      if (node_id != previous_node_id) {
        process_node(f_map, &output_node_ids, &output_gains,
                     &output_feature_dimensions, &output_thresholds,
                     &output_left_node_contribs, &output_right_node_contribs,
                     &output_split_types, previous_node_id, min_node_weight, l1,
                     l2, num_buckets);
        f_map.clear();
      }
      previous_node_id = node_id;
      DCHECK_LE(node_id_first, node_id);
      DCHECK_LT(node_id, node_id_last);
      const int32_t feature_dim = stats_summary_indices(idx, 1);
      const int32_t bucket_id = stats_summary_indices(idx, 2);
      const int32_t stat_dim = stats_summary_indices(idx, 3);
      OP_REQUIRES(context, stat_dim < stats_dims,
                  errors::InvalidArgument(
                      "Stat dim, the sum of logits dim and hessian dim in "
                      "stats_summary_indices, cannot be greater than stats "
                      "dims, the last value in stats_summary_shape, which was ",
                      stats_dims, ". At index (", idx,
                      ", 4), stats_summary_indices contains value ", stat_dim));
      std::pair<FeatureMapIterator, bool> const& f_insert_result = f_map.insert(
          FeatureMapIterator::value_type(feature_dim, BucketMap()));
      auto& b_map = f_insert_result.first->second;
      std::pair<BucketMapIterator, bool> const& b_insert_result =
          b_map.insert(BucketMapIterator::value_type(
              bucket_id, std::vector<float>(stats_dims)));
      auto& stats = b_insert_result.first->second;
      stats[stat_dim] = stats_summary_values(idx);
    }  // for node_id
    // Process the last node id.
    process_node(f_map, &output_node_ids, &output_gains,
                 &output_feature_dimensions, &output_thresholds,
                 &output_left_node_contribs, &output_right_node_contribs,
                 &output_split_types, previous_node_id, min_node_weight, l1, l2,
                 num_buckets);

    const int num_nodes = output_node_ids.size();
    // output_node_ids
    Tensor* output_node_ids_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
                                                     &output_node_ids_t));
    auto output_node_ids_vec = output_node_ids_t->vec<int32>();

    // output_gains
    Tensor* output_gains_t;
    OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
                                                     &output_gains_t));
    auto output_gains_vec = output_gains_t->vec<float>();

    // output_feature_dimensions
    Tensor* output_feature_dimension_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("feature_dimensions", {num_nodes},
                                            &output_feature_dimension_t));
    auto output_feature_dimensions_vec =
        output_feature_dimension_t->vec<int32>();

    // output_thresholds
    Tensor* output_thresholds_t;
    OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
                                                     &output_thresholds_t));
    auto output_thresholds_vec = output_thresholds_t->vec<int32>();

    // output_left_node_contribs
    Tensor* output_left_node_contribs_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("left_node_contribs", {num_nodes, 1},
                                          &output_left_node_contribs_t));
    auto output_left_node_contribs_matrix =
        output_left_node_contribs_t->matrix<float>();

    // output_right_node_contribs
    Tensor* output_right_node_contribs_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("right_node_contribs", {num_nodes, 1},
                                          &output_right_node_contribs_t));
    auto output_right_node_contribs_matrix =
        output_right_node_contribs_t->matrix<float>();

    // split type
    Tensor* output_split_types_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("split_with_default_directions",
                                          {num_nodes}, &output_split_types_t));
    auto output_split_types_vec = output_split_types_t->vec<tstring>();

    // Sets output tensors from vectors.
    for (int i = 0; i < num_nodes; ++i) {
      output_node_ids_vec(i) = output_node_ids[i];
      // Adjust the gains to penalize by tree complexity.
      output_gains_vec(i) = output_gains[i] - tree_complexity;
      output_feature_dimensions_vec(i) = output_feature_dimensions[i];
      output_thresholds_vec(i) = output_thresholds[i];
      // TODO(crawles): change this for multi-class.
      output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
      output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
      output_split_types_vec(i) = output_split_types[i];
    }
  }

 protected:
  void process_node(const FeatureMap& f_map,
                    std::vector<int32>* output_node_ids,
                    std::vector<float>* output_gains,
                    std::vector<int32>* output_feature_dimensions,
                    std::vector<int32>* output_thresholds,
                    std::vector<float>* output_left_node_contribs,
                    std::vector<float>* output_right_node_contribs,
                    std::vector<string>* output_split_types,
                    const int32_t node_id, const float min_node_weight,
                    const float l1, const float l2, const int32_t num_buckets) {
    float parent_gain;
    Eigen::VectorXf unused(logits_dim_);
    Eigen::MatrixXf identity;
    identity.setIdentity(1, 1);

    // Start processing for the previous node id.
    float best_gain = std::numeric_limits<float>::lowest();
    float best_bucket = 0;
    float best_f_dim = 0;
    string best_split_type = boosted_trees::SplitTypeWithDefault_Name(
        boosted_trees::INEQUALITY_DEFAULT_LEFT);
    float best_contrib_for_left = 0.0;
    float best_contrib_for_right = 0.0;
    // The sum of gradients including default bucket.
    float total_grad = 0;
    // The sum of hessians including default bucket.
    float total_hess = 0;

    for (auto f_iter = f_map.begin(); f_iter != f_map.end(); ++f_iter) {
      const int32_t feature_dim = f_iter->first;
      const auto buckets_to_stats_map = f_iter->second;

      // The very last bucket contains stats for missing values.
      // TODO(crawles): use vector for multi-class.
      const float default_grad =
          (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
               ? 0
               : buckets_to_stats_map.at(num_buckets)[0]);
      const float default_hess =
          (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
               ? 0
               : buckets_to_stats_map.at(num_buckets)[1]);

      if (f_iter == f_map.begin()) {
        // First get the sum of grads, including default bucket.
        for (auto b_iter = buckets_to_stats_map.begin();
             b_iter != buckets_to_stats_map.end(); ++b_iter) {
          total_grad += b_iter->second[0];
          total_hess += b_iter->second[1];
        }
        if (total_hess < min_node_weight) {
          // Do not split the node because there is not enough average hessian.
          break;
        }
        CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
                                 l1, l2, &unused, &parent_gain);
      }
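      // As in the dense ops, the per-node totals are taken from the first
      // feature dimension only; every example contributes to exactly one
      // bucket (possibly the default bucket) per feature dimension, so the
      // totals should be identical across dimensions.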

      float total_left_grad = 0;
      float total_left_hess = 0;
      for (auto b_iter = buckets_to_stats_map.begin();
           b_iter != buckets_to_stats_map.end(); ++b_iter) {
        const int32_t bucket_id = b_iter->first;
        // total_left_stats should exclude stats from the default bucket.
        if (bucket_id == num_buckets) {
          break;
        }
        // TODO(crawles): vector for multi-class.
        total_left_grad += b_iter->second[0];
        total_left_hess += b_iter->second[1];
        // Candidate split at this bucket, missing values sent to the right
        // child (default right).
        // Left child.
        Eigen::VectorXf contrib_for_left(1);
        float gain_for_left;
        CalculateWeightsAndGains(total_left_grad * identity,
                                 total_left_hess * identity, l1, l2,
                                 &contrib_for_left, &gain_for_left);
        // Right child.
        Eigen::VectorXf contrib_for_right(1);
        float gain_for_right;
        CalculateWeightsAndGains((total_grad - total_left_grad) * identity,
                                 (total_hess - total_left_hess) * identity, l1,
                                 l2, &contrib_for_right, &gain_for_right);
        if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
          best_gain = gain_for_left + gain_for_right;
          best_bucket = bucket_id;
          best_f_dim = feature_dim;
          best_split_type = boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_RIGHT);
          best_contrib_for_left = contrib_for_left[0];
          best_contrib_for_right = contrib_for_right[0];
        }

        // Same split point, but missing values sent to the left child
        // (default left): the default bucket's stats move to the left side.
        CalculateWeightsAndGains((total_left_grad + default_grad) * identity,
                                 (total_left_hess + default_hess) * identity,
                                 l1, l2, &contrib_for_left, &gain_for_left);
        CalculateWeightsAndGains(
            (total_grad - default_grad - total_left_grad) * identity,
            (total_hess - default_hess - total_left_hess) * identity, l1, l2,
            &contrib_for_right, &gain_for_right);
        if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
          best_gain = gain_for_left + gain_for_right;
          best_bucket = bucket_id;
          best_f_dim = feature_dim;
          best_split_type = boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_LEFT);
          best_contrib_for_left = contrib_for_left[0];
          best_contrib_for_right = contrib_for_right[0];
        }
      }  // for bucket_id
    }  // for feature_dim
    if (best_gain != std::numeric_limits<float>::lowest()) {
      output_node_ids->push_back(node_id);
      // Remove the parent gain.
      output_gains->push_back(best_gain - parent_gain);
      output_feature_dimensions->push_back(best_f_dim);
      output_split_types->push_back(best_split_type);
      output_thresholds->push_back(best_bucket);
      output_left_node_contribs->push_back(best_contrib_for_left);
      output_right_node_contribs->push_back(best_contrib_for_right);
    }
  }
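  // Sketch of the scoring above (illustrative, not executed): for a candidate
  // bucket b, the cumulative (grad, hess) sums up to b form the left child;
  // the remainder forms the right child. Each candidate is scored twice, once
  // routing missing values (the default bucket) to the right and once to the
  // left. Assuming the usual regularized gain of roughly G^2 / (H + l2) (the
  // exact formula, including l1 handling, is in CalculateWeightsAndGains),
  // the candidate with the largest combined child gain wins, and the emitted
  // gain is that value minus parent_gain.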

 private:
  int logits_dim_;
  string split_type_;
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesSparseCalculateBestFeatureSplit").Device(DEVICE_CPU),
    BoostedTreesSparseCalculateBestFeatureSplitOp);

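// Aggregates per-example gradients and hessians into a dense stats summary of
// shape [num_features, max_splits, num_buckets, 2], indexed by
// (feature, node id, bucketized feature value, grad/hess). This op assumes
// single-logit models and, unlike the aggregation ops below, has no extra
// bucket for missing values.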
class BoostedTreesMakeStatsSummaryOp : public OpKernel {
 public:
  explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_ids
    const Tensor* node_ids_t;
    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
    const auto node_ids = node_ids_t->vec<int32>();
    // gradients
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    const auto gradients = gradients_t->matrix<float>();
    // hessians
    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    const auto hessians = hessians_t->matrix<float>();
    // bucketized_features
    OpInputList bucketized_features_list;
    OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
                                                &bucketized_features_list));
    // Infer batch size.
    const int64_t batch_size = node_ids_t->dim_size(0);

    // Allocate temporary stats tensor (Rank 4).
    Tensor temp_stats_double_t;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DT_DOUBLE,
                                {num_features_, max_splits_, num_buckets_, 2},
                                &temp_stats_double_t));
    auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
    temp_stats_double.setZero();

    // Partition by node, and then bucketize.
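    // For example (illustrative): if example i has node_ids(i) == 3 and its
    // bucketized value for feature f is 7, then gradients(i, 0) is added to
    // temp_stats_double(f, 3, 7, 0) and hessians(i, 0) to
    // temp_stats_double(f, 3, 7, 1).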
    for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
      const auto& features =
          bucketized_features_list[feature_idx].vec<int32>();
      for (int i = 0; i < batch_size; ++i) {
        const int32_t node = node_ids(i);
        const int32_t bucket = features(i);
        temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
        temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
      }
    }

    // Copy temp tensor over to output tensor, downcasting to float.
    Tensor* output_stats_summary_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "stats_summary", temp_stats_double_t.shape(),
                                &output_stats_summary_t));
    output_stats_summary_t->tensor<float, 4>() =
        temp_stats_double.template cast<float>();
  }

 private:
  int max_splits_;
  int num_buckets_;
  int num_features_;
};

REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
                        BoostedTreesMakeStatsSummaryOp);

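// Dense stats aggregation. Compared to BoostedTreesMakeStatsSummary above,
// this op supports multi-dimensional features and multi-dimensional
// gradients/hessians, emits a [max_splits, feature_dims, num_buckets + 1,
// stats_dims] summary, and reserves the extra trailing bucket for missing
// values (feature value -1).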
// TODO(tanzheny): Add an option to set the default value in the API interface.
class BoostedTreesAggregateStatsOp : public OpKernel {
 public:
  explicit BoostedTreesAggregateStatsOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_ids.
    const Tensor* node_ids_t;
    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
    const auto node_ids = node_ids_t->vec<int32>();

    // gradients.
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    const auto gradients = gradients_t->matrix<float>();

    // hessians.
    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    const auto hessians = hessians_t->matrix<float>();

    // feature.
    const Tensor* feature_t;
    OP_REQUIRES_OK(context, context->input("feature", &feature_t));
    const auto feature = feature_t->matrix<int32>();

    // Infer batch size, feature dimension and stats dimension.
    const int64_t batch_size = node_ids_t->dim_size(0);
    const int64_t logits_dims = gradients_t->dim_size(1);
    const int64_t hessians_dims = hessians_t->dim_size(1);
    const int64_t stats_dims = logits_dims + hessians_dims;
    const int64_t feature_dims = feature_t->dim_size(1);

    // Allocate temporary stats tensor (Rank 4), upcasting to double.
    // A default bucket is added to the end for missing/default values.
    Tensor temp_stats_double_t;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DT_DOUBLE,
                     {max_splits_, feature_dims, num_buckets_ + 1, stats_dims},
                     &temp_stats_double_t));
    auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
    temp_stats_double.setZero();

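    // Scatter-add each example's stats. Illustrative example: with
    // num_buckets_ == 4, an example i whose feature(i, d) == -1 contributes
    // to the default bucket temp_stats_double(node_ids(i), d, 4, :). Along
    // the last axis, the first logits_dims entries accumulate gradients and
    // the remaining hessians_dims entries accumulate hessians.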
    for (int i = 0; i < batch_size; ++i) {
      const int32_t node = node_ids(i);
      for (int feature_dim = 0; feature_dim < feature_dims; ++feature_dim) {
        const int32_t feature_value = feature(i, feature_dim);
        const int32_t bucket =
            (feature_value == -1) ? num_buckets_ : feature_value;
        for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
          temp_stats_double(node, feature_dim, bucket, stat_dim) +=
              gradients(i, stat_dim);
        }
        for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
          temp_stats_double(node, feature_dim, bucket, stat_dim) +=
              hessians(i, stat_dim - logits_dims);
        }
      }
    }

    // Copy temp tensor over to output tensor, downcasting to float.
    Tensor* output_stats_summary_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "stats_summary", temp_stats_double_t.shape(),
                                &output_stats_summary_t));
    output_stats_summary_t->tensor<float, 4>() =
        temp_stats_double.template cast<float>();
  }

 private:
  int max_splits_;
  int num_buckets_;
};

REGISTER_KERNEL_BUILDER(Name("BoostedTreesAggregateStats").Device(DEVICE_CPU),
                        BoostedTreesAggregateStatsOp);

// Key based on node id, feature dimension and bucket id.
struct StatsPartitionKey {
  StatsPartitionKey(const int32_t node_id, const int32_t feature_dim,
                    const int32_t bucket_id)
      : node_id(node_id), feature_dim(feature_dim), bucket_id(bucket_id) {}

  bool operator==(const StatsPartitionKey& other) const {
    return (node_id == other.node_id) && (feature_dim == other.feature_dim) &&
           (bucket_id == other.bucket_id);
  }

  // Comparator for StatsPartitionKey: lexicographic order on
  // (node_id, feature_dim, bucket_id).
  struct Less {
    bool operator()(const StatsPartitionKey& a,
                    const StatsPartitionKey& b) const {
      if (a.node_id < b.node_id) {
        return true;
      }
      if ((a.node_id == b.node_id) && (a.feature_dim < b.feature_dim)) {
        return true;
      }
      if ((a.node_id == b.node_id) && (a.feature_dim == b.feature_dim) &&
          (a.bucket_id < b.bucket_id)) {
        return true;
      }
      return false;
    }
  };

  // Tree node id.
  int32 node_id;
  // Dimension within the feature column.
  int32 feature_dim;
  // Bucketized feature value.
  int32 bucket_id;
};

typedef std::map<StatsPartitionKey, std::vector<float>, StatsPartitionKey::Less>
    StatsPartitionMap;
typedef StatsPartitionMap::iterator StatsPartitionIterator;
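
// Because StatsPartitionMap is ordered by (node_id, feature_dim, bucket_id),
// iterating over it visits entries in lexicographic index order, which is the
// order in which BoostedTreesSparseAggregateStats below serializes its sparse
// output.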

// Key based on instance and feature dimension.
struct InstanceFeatureDimKey {
  InstanceFeatureDimKey() : instance(-1), feature_dim(-1) {}

  InstanceFeatureDimKey(const int32_t instance, const int32_t feature_dim)
      : instance(instance), feature_dim(feature_dim) {}

  bool operator==(const InstanceFeatureDimKey& other) const {
    return (instance == other.instance) && (feature_dim == other.feature_dim);
  }

  // Comparator for InstanceFeatureDimKey: lexicographic order on
  // (instance, feature_dim).
  struct Less {
    bool operator()(const InstanceFeatureDimKey& a,
                    const InstanceFeatureDimKey& b) const {
      if (a.instance < b.instance) {
        return true;
      }
      if ((a.instance == b.instance) && (a.feature_dim < b.feature_dim)) {
        return true;
      }
      return false;
    }
  };

  // Instance id within a batch.
  int32 instance;
  // Dimension within the feature column.
  int32 feature_dim;
};

// Add statistics to StatsPartitionMap for (instance, feature dim, bucket id).
static void AddInstanceStatsToMap(
    const int32_t instance, const int32_t feature_dim, const int32_t bucket_id,
    const int32_t logits_dims, const int32_t stats_dims,
    StatsPartitionMap* stats_map, const TTypes<float>::ConstMatrix& gradients,
    const TTypes<float>::ConstMatrix& hessians,
    const TTypes<int32>::ConstVec& node_ids) {
  const int32_t node_id = node_ids(instance);
  const auto key = StatsPartitionKey(node_id, feature_dim, bucket_id);
  std::pair<StatsPartitionIterator, bool> const& insert_result =
      stats_map->insert(StatsPartitionIterator::value_type(
          key, std::vector<float>(stats_dims, 0.0f)));
  auto& stats = insert_result.first->second;
  for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
    stats[stat_dim] += gradients(instance, stat_dim);
  }
  for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
    stats[stat_dim] += hessians(instance, stat_dim - logits_dims);
  }
}

// Add statistics (for a fixed bucket_id) to StatsPartitionMap for every
// (instance, feature_dim) pair strictly between (start_instance,
// start_feature_dim) and (end_instance, end_feature_dim) in lexicographic
// order, i.e. exclusive of both endpoints.
static void AddRangeStats(const int start_instance, const int end_instance,
                          const int start_feature_dim,
                          const int end_feature_dim,
                          StatsPartitionMap* stats_map,
                          const TTypes<float>::ConstMatrix& gradients,
                          const TTypes<float>::ConstMatrix& hessians,
                          const TTypes<int32>::ConstVec& node_ids,
                          const int32_t feature_dims, const int32_t bucket_id,
                          const int32_t logits_dims, const int32_t stats_dims) {
  DCHECK_LE(start_instance, end_instance);
  if (start_instance == end_instance) {
    DCHECK_LT(start_feature_dim, end_feature_dim);
  }
  for (int32_t instance = start_instance; instance <= end_instance;
       ++instance) {
    const int32_t start_f_dim =
        (instance == start_instance) ? start_feature_dim + 1 : 0;
    const int32_t end_f_dim =
        (instance == end_instance) ? end_feature_dim : feature_dims;
    for (int32_t f_dim = start_f_dim; f_dim < end_f_dim; ++f_dim) {
      AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
                            stats_map, gradients, hessians, node_ids);
    }
  }
}
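
// Illustrative example (not executed): with feature_dims == 3, a previous
// sparse entry at (instance 0, feature_dim 0) and a current entry at
// (instance 1, feature_dim 1), the call
//   AddRangeStats(0, 1, 0, 1, ...)
// adds default-bucket stats for (0, 1), (0, 2) and (1, 0) -- exactly the
// (instance, feature_dim) pairs that have no explicit sparse value between
// the two entries.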

class BoostedTreesSparseAggregateStatsOp : public OpKernel {
 public:
  explicit BoostedTreesSparseAggregateStatsOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_ids.
    const Tensor* node_ids_t;
    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
    const auto node_ids = node_ids_t->vec<int32>();

    // gradients.
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    const auto gradients = gradients_t->matrix<float>();

    // hessians.
    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    const auto hessians = hessians_t->matrix<float>();

    // feature indices.
    const Tensor* feature_indices_t;
    OP_REQUIRES_OK(context,
                   context->input("feature_indices", &feature_indices_t));
    const auto feature_indices = feature_indices_t->matrix<int32>();

    // feature values.
    const Tensor* feature_values_t;
    OP_REQUIRES_OK(context,
                   context->input("feature_values", &feature_values_t));
    const auto feature_values = feature_values_t->vec<int32>();

    // feature shape.
    const Tensor* feature_shape_t;
    OP_REQUIRES_OK(context, context->input("feature_shape", &feature_shape_t));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(feature_shape_t->shape()),
                errors::InvalidArgument(
                    "Input feature_shape should be a vector, but received "
                    "shape ",
                    feature_shape_t->shape().DebugString()));
    const auto feature_shape = feature_shape_t->vec<int32>();

    const int64_t batch_size = gradients_t->dim_size(0);
    const int64_t logits_dims = gradients_t->dim_size(1);
    const int64_t hessians_dims = hessians_t->dim_size(1);
    const int64_t stats_dims = logits_dims + hessians_dims;
    const int64_t num_sparse_entries = feature_indices_t->dim_size(0);
    const int32_t feature_dims = feature_shape(1);
    DCHECK_LE(num_sparse_entries, batch_size * feature_dims);

    // Aggregate statistics into the map.
    StatsPartitionMap stats_map;

    int prev_instance = 0;
    int prev_f_dim = -1;

    for (int i = 0; i < num_sparse_entries; ++i) {
      // The instance number within the batch.
      const int32_t instance = feature_indices(i, 0);
      DCHECK_LE(instance, batch_size);
      DCHECK_GE(instance, prev_instance);
      // The node id within the tree.
      const int32_t node_id = node_ids(instance);
      DCHECK_LE(node_id, max_splits_);
      // The feature dimension.
      const int32_t f_dim = feature_indices(i, 1);
      DCHECK_LE(f_dim, feature_dims);
      // The bucket id of the value.
      const int32_t bucket_id = feature_values(i);
      DCHECK_LE(bucket_id, num_buckets_);

      // Add statistics for the missing entries into the default bucket.
      // The last bucket is the default bucket.
      const int missing_entry_bucket = num_buckets_;
      AddRangeStats(prev_instance, instance, prev_f_dim, f_dim, &stats_map,
                    gradients, hessians, node_ids, feature_dims,
                    missing_entry_bucket, logits_dims, stats_dims);
      prev_instance = instance;
      prev_f_dim = f_dim;
      // Add statistics for the non-missing entry into
      // (cur_instance, cur_f_dim, bucket_id).
      AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
                            &stats_map, gradients, hessians, node_ids);
    }
    AddRangeStats(prev_instance, batch_size - 1, prev_f_dim, feature_dims,
                  &stats_map, gradients, hessians, node_ids, feature_dims,
                  num_buckets_, logits_dims, stats_dims);
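    // The call above covers the remaining missing entries, from just after
    // the last explicit sparse entry (prev_instance, prev_f_dim) through the
    // last (instance, feature_dim) pair of the batch, all credited to the
    // default bucket num_buckets_.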

    // Serialize the statistics map to the sparse tensor outputs.
    const int64_t num_slots = stats_map.size() * stats_dims;
    Tensor* summary_indices_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stats_summary_indices",
                                            TensorShape({num_slots, 4}),
                                            &summary_indices_t));
    auto summary_indices = summary_indices_t->matrix<int32>();
    Tensor* summary_values_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("stats_summary_values",
                                                     TensorShape({num_slots}),
                                                     &summary_values_t));
    auto summary_values = summary_values_t->vec<float>();
    int entry_index = 0;
    for (auto& iter : stats_map) {
      for (int stat_dim = 0; stat_dim < stats_dims; ++stat_dim) {
        summary_indices(entry_index, 0) = iter.first.node_id;
        summary_indices(entry_index, 1) = iter.first.feature_dim;
        summary_indices(entry_index, 2) = iter.first.bucket_id;
        summary_indices(entry_index, 3) = stat_dim;
        summary_values(entry_index) = iter.second[stat_dim];
        ++entry_index;
      }
    }
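    // Note: each map entry above expands into stats_dims consecutive rows.
    // For example (illustrative), a node 2 / feature_dim 0 / bucket 5 entry
    // with stats_dims == 2 produces rows [2, 0, 5, 0] and [2, 0, 5, 1]
    // holding the accumulated gradient and hessian sums, respectively.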

    Tensor* summary_shape_t = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output("stats_summary_shape",
                                          TensorShape({4}), &summary_shape_t));
    auto summary_shape = summary_shape_t->vec<int32>();
    summary_shape(0) = max_splits_;
    summary_shape(1) = feature_dims;
    summary_shape(2) = num_buckets_ + 1;
    summary_shape(3) = stats_dims;
  }

 private:
  int max_splits_;
  int num_buckets_;
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesSparseAggregateStats").Device(DEVICE_CPU),
    BoostedTreesSparseAggregateStatsOp);

}  // namespace tensorflow