• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <limits>
17 #include <string>
18 #include <vector>
19 
20 #include "third_party/eigen3/Eigen/Core"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
25 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/logging.h"
28 
29 namespace tensorflow {
30 
// Dynamically sized, row-major float matrix.
using Matrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
// Eigen::Map wrappers that interpret an existing float buffer as a Matrix
// (read-only and writable variants).
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;

// Eigen::Map wrappers that interpret an existing float buffer as a column
// vector (read-only and writable variants).
using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
using VectorMap = Eigen::Map<Eigen::VectorXf>;

// Names of the split types the kernels in this file compare against.
constexpr char kInequalitySplit[] = "inequality";
constexpr char kEqualitySplit[] = "equality";
41 
42 // V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOpV2 is V2.
43 class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
44  public:
BoostedTreesCalculateBestGainsPerFeatureOp(OpKernelConstruction * const context)45   explicit BoostedTreesCalculateBestGainsPerFeatureOp(
46       OpKernelConstruction* const context)
47       : OpKernel(context) {
48     OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
49     OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
50   }
51 
Compute(OpKernelContext * const context)52   void Compute(OpKernelContext* const context) override {
53     // node_id_range
54     const Tensor* node_id_range_t;
55     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
56     OP_REQUIRES(
57         context, node_id_range_t->dims() == 1,
58         errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
59                                 "given node_id_range has dims of ",
60                                 node_id_range_t->dims()));
61     OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
62                 errors::InvalidArgument(
63                     "node_id_range must be a rank 1 tensor with shape=[2], but "
64                     "given node_id_range has shape ",
65                     node_id_range_t->dim_size(0), " on its first dim"));
66     const auto node_id_range = node_id_range_t->vec<int32>();
67     const int32_t node_id_first = node_id_range(0);  // inclusive
68     const int32_t node_id_last = node_id_range(1);   // exclusive
69     // stats_summary_list
70     OpInputList stats_summary_list;
71     OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
72                                                 &stats_summary_list));
73     const int64_t num_buckets = stats_summary_list[0].dim_size(1);
74     // Check for single logit: 1 gradient + 1 hessian value.
75     DCHECK_EQ(stats_summary_list[0].dim_size(2), 2);
76     std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
77     stats_summary.reserve(stats_summary_list.size());
78     for (const auto& tensor : stats_summary_list) {
79       stats_summary.emplace_back(tensor.tensor<float, 3>());
80     }
81     const Tensor* l1_t;
82     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
83     const auto l1 = l1_t->scalar<float>()();
84     const Tensor* l2_t;
85     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
86     const auto l2 = l2_t->scalar<float>()();
87     const Tensor* tree_complexity_t;
88     OP_REQUIRES_OK(context,
89                    context->input("tree_complexity", &tree_complexity_t));
90     const auto tree_complexity = tree_complexity_t->scalar<float>()();
91     const Tensor* min_node_weight_t;
92     OP_REQUIRES_OK(context,
93                    context->input("min_node_weight", &min_node_weight_t));
94     const auto min_node_weight = min_node_weight_t->scalar<float>()();
95 
96     // Allocate output lists of tensors:
97     OpOutputList output_node_ids_list;
98     OP_REQUIRES_OK(
99         context, context->output_list("node_ids_list", &output_node_ids_list));
100     OpOutputList output_gains_list;
101     OP_REQUIRES_OK(context,
102                    context->output_list("gains_list", &output_gains_list));
103     OpOutputList output_thresholds_list;
104     OP_REQUIRES_OK(context, context->output_list("thresholds_list",
105                                                  &output_thresholds_list));
106     OpOutputList output_left_node_contribs_list;
107     OP_REQUIRES_OK(context,
108                    context->output_list("left_node_contribs_list",
109                                         &output_left_node_contribs_list));
110     OpOutputList output_right_node_contribs_list;
111     OP_REQUIRES_OK(context,
112                    context->output_list("right_node_contribs_list",
113                                         &output_right_node_contribs_list));
114 
115     // Use identity later to convert float to Eigen::Matrix type for input to
116     // CalculateWeightsAndGains. This op only supports single dimension logits.
117     Eigen::MatrixXf identity;
118     identity.setIdentity(1, 1);
119     // Get the best split info per node for each feature.
120     for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
121       std::vector<float> cum_grad;
122       std::vector<float> cum_hess;
123       cum_grad.reserve(num_buckets);
124       cum_hess.reserve(num_buckets);
125 
126       std::vector<int32> output_node_ids;
127       std::vector<float> output_gains;
128       std::vector<int32> output_thresholds;
129       std::vector<float> output_left_node_contribs;
130       std::vector<float> output_right_node_contribs;
131       for (int node_id = node_id_first; node_id < node_id_last; ++node_id) {
132         // Calculate gains.
133         cum_grad.clear();
134         cum_hess.clear();
135         float total_grad = 0.0;
136         float total_hess = 0.0;
137         for (int bucket = 0; bucket < num_buckets; ++bucket) {
138           // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
139           total_grad += stats_summary[feature_idx](node_id, bucket, 0);
140           total_hess += stats_summary[feature_idx](node_id, bucket, 1);
141           cum_grad.push_back(total_grad);
142           cum_hess.push_back(total_hess);
143         }
144         // Check if node has enough of average hessian.
145         if (total_hess < min_node_weight) {
146           // Do not split the node because not enough avg hessian.
147           continue;
148         }
149         float best_gain = std::numeric_limits<float>::lowest();
150         float best_bucket = 0;
151         float best_contrib_for_left = 0.0;
152         float best_contrib_for_right = 0.0;
153         // Parent gain.
154         float parent_gain;
155         Eigen::VectorXf unused(1);
156         CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
157                                  l1, l2, &unused, &parent_gain);
158 
159         for (int bucket = 0; bucket < num_buckets; ++bucket) {
160           const float cum_grad_bucket = cum_grad[bucket];
161           const float cum_hess_bucket = cum_hess[bucket];
162           // Left child.
163           Eigen::VectorXf contrib_for_left(1);
164           float gain_for_left;
165           CalculateWeightsAndGains(cum_grad_bucket * identity,
166                                    cum_hess_bucket * identity, l1, l2,
167                                    &contrib_for_left, &gain_for_left);
168           // Right child.
169           // use contrib_for_right.
170           Eigen::VectorXf contrib_for_right(1);
171           float gain_for_right;
172           CalculateWeightsAndGains((total_grad - cum_grad_bucket) * identity,
173                                    (total_hess - cum_hess_bucket) * identity,
174                                    l1, l2, &contrib_for_right, &gain_for_right);
175 
176           if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
177             best_gain = gain_for_left + gain_for_right;
178             best_bucket = bucket;
179             best_contrib_for_left = contrib_for_left[0];
180             best_contrib_for_right = contrib_for_right[0];
181           }
182         }  // for bucket
183         output_node_ids.push_back(node_id);
184         // Remove the parent gain for the parent node.
185         output_gains.push_back(best_gain - parent_gain);
186         output_thresholds.push_back(best_bucket);
187         output_left_node_contribs.push_back(best_contrib_for_left);
188         output_right_node_contribs.push_back(best_contrib_for_right);
189       }  // for node_id
190       const int num_nodes = output_node_ids.size();
191       // output_node_ids
192       Tensor* output_node_ids_t;
193       OP_REQUIRES_OK(context,
194                      output_node_ids_list.allocate(feature_idx, {num_nodes},
195                                                    &output_node_ids_t));
196       auto output_node_ids_vec = output_node_ids_t->vec<int32>();
197       // output_gains
198       Tensor* output_gains_t;
199       OP_REQUIRES_OK(context, output_gains_list.allocate(
200                                   feature_idx, {num_nodes}, &output_gains_t));
201       auto output_gains_vec = output_gains_t->vec<float>();
202       // output_thresholds
203       Tensor* output_thresholds_t;
204       OP_REQUIRES_OK(context,
205                      output_thresholds_list.allocate(feature_idx, {num_nodes},
206                                                      &output_thresholds_t));
207       auto output_thresholds_vec = output_thresholds_t->vec<int32>();
208       // output_left_node_contribs
209       Tensor* output_left_node_contribs_t;
210       OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
211                                   feature_idx, {num_nodes, 1},
212                                   &output_left_node_contribs_t));
213       auto output_left_node_contribs_matrix =
214           output_left_node_contribs_t->matrix<float>();
215       // output_right_node_contribs
216       Tensor* output_right_node_contribs_t;
217       OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
218                                   feature_idx, {num_nodes, 1},
219                                   &output_right_node_contribs_t));
220       auto output_right_node_contribs_matrix =
221           output_right_node_contribs_t->matrix<float>();
222       // Sets output tensors from vectors.
223       for (int i = 0; i < num_nodes; ++i) {
224         output_node_ids_vec(i) = output_node_ids[i];
225         // Adjust the gains to penalize by tree complexity.
226         output_gains_vec(i) = output_gains[i] - tree_complexity;
227         output_thresholds_vec(i) = output_thresholds[i];
228         output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
229         // This op only supports 1-dimensional logits.
230         output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
231       }
232     }  // for f
233   }
234 
235  private:
236   int max_splits_;
237   int num_features_;
238 };
239 
// V1 op that only supports single dimensional logit.
// Registers BoostedTreesCalculateBestGainsPerFeatureOp as the CPU kernel for
// the "BoostedTreesCalculateBestGainsPerFeature" op.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
    BoostedTreesCalculateBestGainsPerFeatureOp);
244 
245 // Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
246 class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
247  public:
BoostedTreesCalculateBestFeatureSplitOp(OpKernelConstruction * const context)248   explicit BoostedTreesCalculateBestFeatureSplitOp(
249       OpKernelConstruction* const context)
250       : OpKernel(context) {
251     OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
252     OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
253   }
254 
Compute(OpKernelContext * const context)255   void Compute(OpKernelContext* const context) override {
256     // node_id_range
257     const Tensor* node_id_range_t;
258     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
259     OP_REQUIRES(
260         context, node_id_range_t->NumElements() == 2,
261         errors::InvalidArgument("node_id_range argument must have shape [2]"));
262     const auto node_id_range = node_id_range_t->vec<int32>();
263     const int32_t node_id_first = node_id_range(0);  // inclusive
264     const int32_t node_id_last = node_id_range(1);   // exclusive
265 
266     const Tensor* stats_summary_t;
267     OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
268     OP_REQUIRES(
269         context, stats_summary_t->shape().dims() == 4,
270         errors::InvalidArgument("stats_summary argument must have rank 4"));
271     TTypes<float, 4>::ConstTensor stats_summary =
272         stats_summary_t->tensor<float, 4>();
273     const int32_t feature_dims = stats_summary_t->dim_size(1);
274     // The last bucket is for default/missing value.
275     const int32_t num_buckets = stats_summary_t->dim_size(2) - 1;
276     const int32_t logits_dim = logits_dim_;
277     const int32_t hessian_dim = stats_summary_t->dim_size(3) - logits_dim;
278     DCHECK_GT(hessian_dim, 0);
279     DCHECK_LE(hessian_dim, logits_dim * logits_dim);
280 
281     const Tensor* l1_t;
282     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
283     OP_REQUIRES(context, l1_t->NumElements() == 1,
284                 errors::InvalidArgument("l1 argument must be a scalar"));
285     const auto l1 = l1_t->scalar<float>()();
286     DCHECK_GE(l1, 0);
287     if (logits_dim_ > 1) {
288       // Multi-class L1 regularization not supported yet.
289       DCHECK_EQ(l1, 0);
290     }
291 
292     const Tensor* l2_t;
293     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
294     OP_REQUIRES(context, l2_t->NumElements() == 1,
295                 errors::InvalidArgument("l2 argument must be a scalar"));
296     const auto l2 = l2_t->scalar<float>()();
297     DCHECK_GE(l2, 0);
298 
299     const Tensor* tree_complexity_t;
300     OP_REQUIRES_OK(context,
301                    context->input("tree_complexity", &tree_complexity_t));
302     OP_REQUIRES(
303         context, tree_complexity_t->NumElements() == 1,
304         errors::InvalidArgument("tree_complexity argument must be a scalar"));
305     const auto tree_complexity = tree_complexity_t->scalar<float>()();
306 
307     const Tensor* min_node_weight_t;
308     OP_REQUIRES_OK(context,
309                    context->input("min_node_weight", &min_node_weight_t));
310     OP_REQUIRES(
311         context, min_node_weight_t->NumElements() == 1,
312         errors::InvalidArgument("min_node_weight argument must be a scalar"));
313     const auto min_node_weight = min_node_weight_t->scalar<float>()();
314 
315     std::vector<int32> output_node_ids;
316     std::vector<float> output_gains;
317     std::vector<int32> output_feature_dimensions;
318     std::vector<int32> output_thresholds;
319     std::vector<Eigen::VectorXf> output_left_node_contribs;
320     std::vector<Eigen::VectorXf> output_right_node_contribs;
321     std::vector<std::string> output_split_types;
322 
323     // TODO(tanzheny) parallelize the computation.
324     // Iterate each node and find the best gain per node.
325     for (int32_t node_id = node_id_first; node_id < node_id_last; ++node_id) {
326       float best_gain = std::numeric_limits<float>::lowest();
327       int32_t best_bucket = 0;
328       int32_t best_f_dim = 0;
329       string best_split_type;
330       Eigen::VectorXf best_contrib_for_left(logits_dim);
331       Eigen::VectorXf best_contrib_for_right(logits_dim);
332       float parent_gain;
333 
334       // Including default bucket.
335       ConstMatrixMap stats_mat(&stats_summary(node_id, 0, 0, 0),
336                                num_buckets + 1, logits_dim + hessian_dim);
337       const Eigen::VectorXf total_grad =
338           stats_mat.leftCols(logits_dim).colwise().sum();
339       const Eigen::VectorXf total_hess =
340           stats_mat.rightCols(hessian_dim).colwise().sum();
341       if (total_hess.norm() < min_node_weight) {
342         continue;
343       }
344       Eigen::VectorXf parent_weight(logits_dim);
345       CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &parent_weight,
346                                &parent_gain);
347 
348       if (split_type_ == "inequality") {
349         CalculateBestInequalitySplit(
350             stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
351             num_buckets, min_node_weight, l1, l2, &best_gain, &best_bucket,
352             &best_f_dim, &best_split_type, &best_contrib_for_left,
353             &best_contrib_for_right);
354       } else {
355         CalculateBestEqualitySplit(
356             stats_summary, total_grad, total_hess, node_id, feature_dims,
357             logits_dim, hessian_dim, num_buckets, l1, l2, &best_gain,
358             &best_bucket, &best_f_dim, &best_split_type, &best_contrib_for_left,
359             &best_contrib_for_right);
360       }
361 
362       if (best_gain == std::numeric_limits<float>::lowest()) {
363         // Do not add the node if not split if found.
364         continue;
365       }
366       output_node_ids.push_back(node_id);
367       // Remove the parent gain for the parent node.
368       output_gains.push_back(best_gain - parent_gain);
369       output_feature_dimensions.push_back(best_f_dim);
370       // default direction is fixed for dense splits.
371       // TODO(tanzheny) account for default values.
372       output_split_types.push_back(best_split_type);
373       output_thresholds.push_back(best_bucket);
374       output_left_node_contribs.push_back(best_contrib_for_left);
375       output_right_node_contribs.push_back(best_contrib_for_right);
376     }  // for node id
377     const int num_nodes = output_node_ids.size();
378     // output_node_ids
379     Tensor* output_node_ids_t = nullptr;
380     OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
381                                                      &output_node_ids_t));
382     auto output_node_ids_vec = output_node_ids_t->vec<int32>();
383 
384     // output_gains
385     Tensor* output_gains_t;
386     OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
387                                                      &output_gains_t));
388     auto output_gains_vec = output_gains_t->vec<float>();
389 
390     // output_feature_dimensions
391     Tensor* output_feature_dimension_t;
392     OP_REQUIRES_OK(context,
393                    context->allocate_output("feature_dimensions", {num_nodes},
394                                             &output_feature_dimension_t));
395     auto output_feature_dimensions_vec =
396         output_feature_dimension_t->vec<int32>();
397 
398     // output_thresholds
399     Tensor* output_thresholds_t;
400     OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
401                                                      &output_thresholds_t));
402     auto output_thresholds_vec = output_thresholds_t->vec<int32>();
403 
404     // output_left_node_contribs
405     Tensor* output_left_node_contribs_t;
406     OP_REQUIRES_OK(context, context->allocate_output(
407                                 "left_node_contribs", {num_nodes, logits_dim},
408                                 &output_left_node_contribs_t));
409     auto output_left_node_contribs_matrix =
410         output_left_node_contribs_t->matrix<float>();
411 
412     // output_right_node_contribs
413     Tensor* output_right_node_contribs_t;
414     OP_REQUIRES_OK(context, context->allocate_output(
415                                 "right_node_contribs", {num_nodes, logits_dim},
416                                 &output_right_node_contribs_t));
417     auto output_right_node_contribs_matrix =
418         output_right_node_contribs_t->matrix<float>();
419 
420     // split type
421     Tensor* output_split_types_t;
422     OP_REQUIRES_OK(
423         context, context->allocate_output("split_with_default_directions",
424                                           {num_nodes}, &output_split_types_t));
425     auto output_split_types_vec = output_split_types_t->vec<tstring>();
426 
427     // Sets output tensors from vectors.
428     for (int i = 0; i < num_nodes; ++i) {
429       output_node_ids_vec(i) = output_node_ids[i];
430       // Adjust the gains to penalize by tree complexity.
431       output_gains_vec(i) = output_gains[i] - tree_complexity;
432       output_feature_dimensions_vec(i) = output_feature_dimensions[i];
433       output_thresholds_vec(i) = output_thresholds[i];
434       for (int j = 0; j < logits_dim; ++j) {
435         output_left_node_contribs_matrix(i, j) =
436             output_left_node_contribs[i][j];
437         output_right_node_contribs_matrix(i, j) =
438             output_right_node_contribs[i][j];
439       }
440       output_split_types_vec(i) = output_split_types[i];
441     }
442   }
443 
444  private:
445   // TODO(crawles): Simplify inequality path just like equality b/138329196
446   // Currently this is not simplify-able due to numerical instability in math
447   // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g)
448   // It caused gain to be Inf when g is approaching 0 but not exactly 0 while
449   // there is no regularization.
450   // Calculate the best inequality split per node.
CalculateBestInequalitySplit(TTypes<float,4>::ConstTensor stats_summary,const int32_t node_id,const int32_t feature_dims,const int32_t logits_dim,const int32_t hessian_dim,const int32_t num_buckets,const float min_node_weight,const float l1,const float l2,float * best_gain,int32 * best_bucket,int32 * best_f_dim,string * best_split_type,Eigen::VectorXf * best_contrib_for_left,Eigen::VectorXf * best_contrib_for_right)451   void CalculateBestInequalitySplit(
452       TTypes<float, 4>::ConstTensor stats_summary, const int32_t node_id,
453       const int32_t feature_dims, const int32_t logits_dim,
454       const int32_t hessian_dim, const int32_t num_buckets,
455       const float min_node_weight, const float l1, const float l2,
456       float* best_gain, int32* best_bucket, int32* best_f_dim,
457       string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
458       Eigen::VectorXf* best_contrib_for_right) {
459     std::vector<Eigen::VectorXf> cum_grad;
460     std::vector<Eigen::VectorXf> cum_hess;
461     // get all cumulative gradients including default bucket.
462     cum_grad.reserve(num_buckets);
463     cum_hess.reserve(num_buckets);
464 
465     for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
466       ConstVectorMap default_stats_vec(
467           &stats_summary(node_id, f_dim, num_buckets, 0),
468           logits_dim + hessian_dim);
469       Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
470       Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
471       cum_grad.clear();
472       cum_hess.clear();
473       Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
474       Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
475       // sum all the gradients including default bucket.
476       for (int bucket = 0; bucket <= num_buckets; ++bucket) {
477         for (int i = 0; i < logits_dim; ++i) {
478           total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
479         }
480         for (int i = 0; i < hessian_dim; ++i) {
481           // Full hessian.
482           total_hess[i] +=
483               stats_summary(node_id, f_dim, bucket, logits_dim + i);
484         }
485         if (bucket < num_buckets) {
486           cum_grad.push_back(total_grad);
487           cum_hess.push_back(total_hess);
488         }
489       }
490       const string kInequalityDefaultLeft =
491           boosted_trees::SplitTypeWithDefault_Name(
492               boosted_trees::INEQUALITY_DEFAULT_LEFT);
493       const string kInequalityDefaultRight =
494           boosted_trees::SplitTypeWithDefault_Name(
495               boosted_trees::INEQUALITY_DEFAULT_RIGHT);
496 
497       // Iterate from left to right, excluding default bucket.
498       for (int bucket = 0; bucket < num_buckets; ++bucket) {
499         // default value goes to left node.
500         const Eigen::VectorXf total_left_grad =
501             cum_grad[bucket] + missing_bucket_grad;
502         const Eigen::VectorXf total_left_hess =
503             cum_hess[bucket] + missing_bucket_hess;
504         MaybeUpdateBestSplit(
505             total_left_grad, total_grad - total_left_grad, total_left_hess,
506             total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
507             kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
508             best_split_type, best_contrib_for_left, best_contrib_for_right);
509         // default value goes to right node.
510         MaybeUpdateBestSplit(
511             cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
512             total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
513             kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
514             best_split_type, best_contrib_for_left, best_contrib_for_right);
515       }  // for bucket
516     }
517   }
518 
519   // Calculate the best equality split per node.
CalculateBestEqualitySplit(TTypes<float,4>::ConstTensor stats_summary,const Eigen::VectorXf & total_grad,const Eigen::VectorXf & total_hess,const int32_t node_id,const int32_t feature_dims,const int32_t logits_dim,const int32_t hessian_dim,const int32_t num_buckets,const float l1,const float l2,float * best_gain,int32 * best_bucket,int32 * best_f_dim,string * best_split_type,Eigen::VectorXf * best_contrib_for_left,Eigen::VectorXf * best_contrib_for_right)520   void CalculateBestEqualitySplit(
521       TTypes<float, 4>::ConstTensor stats_summary,
522       const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
523       const int32_t node_id, const int32_t feature_dims,
524       const int32_t logits_dim, const int32_t hessian_dim,
525       const int32_t num_buckets, const float l1, const float l2,
526       float* best_gain, int32* best_bucket, int32* best_f_dim,
527       string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
528       Eigen::VectorXf* best_contrib_for_right) {
529     const string kEqualityDefaultRight =
530         boosted_trees::SplitTypeWithDefault_Name(
531             boosted_trees::EQUALITY_DEFAULT_RIGHT);
532     for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
533       for (int bucket = 0; bucket < num_buckets; ++bucket) {
534         ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
535                                  logits_dim + hessian_dim);
536         Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
537         Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
538         MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
539                              total_hess - curr_hess, logits_dim, bucket, f_dim,
540                              l1, l2, kEqualityDefaultRight, best_gain,
541                              best_bucket, best_f_dim, best_split_type,
542                              best_contrib_for_left, best_contrib_for_right);
543       }
544     }
545   }
546 
MaybeUpdateBestSplit(const Eigen::VectorXf & grad_for_left,const Eigen::VectorXf & grad_for_right,const Eigen::VectorXf & hess_for_left,const Eigen::VectorXf & hess_for_right,const int32_t logits_dim,const int32_t bucket,const int32_t f_dim,const float l1,const float l2,const string split_type,float * best_gain,int32 * best_bucket,int32 * best_f_dim,string * best_split_type,Eigen::VectorXf * best_contrib_for_left,Eigen::VectorXf * best_contrib_for_right)547   void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
548                             const Eigen::VectorXf& grad_for_right,
549                             const Eigen::VectorXf& hess_for_left,
550                             const Eigen::VectorXf& hess_for_right,
551                             const int32_t logits_dim, const int32_t bucket,
552                             const int32_t f_dim, const float l1, const float l2,
553                             const string split_type, float* best_gain,
554                             int32* best_bucket, int32* best_f_dim,
555                             string* best_split_type,
556                             Eigen::VectorXf* best_contrib_for_left,
557                             Eigen::VectorXf* best_contrib_for_right) {
558     // Left child.
559     Eigen::VectorXf contrib_for_left(logits_dim);
560     float gain_for_left;
561     CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
562                              &contrib_for_left, &gain_for_left);
563     Eigen::VectorXf contrib_for_right(logits_dim);
564     float gain_for_right;
565     CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
566                              &contrib_for_right, &gain_for_right);
567     if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
568       *best_gain = gain_for_left + gain_for_right;
569       *best_bucket = bucket;
570       *best_f_dim = f_dim;
571       *best_contrib_for_left = contrib_for_left;
572       *best_contrib_for_right = contrib_for_right;
573       *best_split_type = split_type;
574     }
575   }
576 
577   int logits_dim_;
578   string split_type_;
579 };
580 
// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
// Registers BoostedTreesCalculateBestFeatureSplitOp as the CPU kernel for the
// "BoostedTreesCalculateBestFeatureSplit" op.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
    BoostedTreesCalculateBestFeatureSplitOp);
585 
586 // V2 Op.
587 class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
588  public:
BoostedTreesCalculateBestFeatureSplitV2(OpKernelConstruction * const context)589   explicit BoostedTreesCalculateBestFeatureSplitV2(
590       OpKernelConstruction* const context)
591       : OpKernel(context) {
592     OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
593     OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
594   }
595 
Compute(OpKernelContext * const context)596   void Compute(OpKernelContext* const context) override {
597     // node_id_range
598     const Tensor* node_id_range_t;
599     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
600     const auto node_id_range = node_id_range_t->vec<int32>();
601     OP_REQUIRES(
602         context, node_id_range_t->dims() == 1,
603         errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
604                                 "given node_id_range has dims of ",
605                                 node_id_range_t->dims()));
606     OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
607                 errors::InvalidArgument(
608                     "node_id_range must be a rank 1 tensor with shape=[2], but "
609                     "given node_id_range has shape ",
610                     node_id_range_t->dim_size(0), " on its first dim"));
611     const int32_t node_id_first = node_id_range(0);  // Inclusive.
612     const int32_t node_id_last = node_id_range(1);   // Exclusive.
613 
614     // Get stats_summaries_list.
615     OpInputList stats_summaries_list;
616     OP_REQUIRES_OK(context, context->input_list("stats_summaries_list",
617                                                 &stats_summaries_list));
618 
619     // Infer dimensions of a stats_summary.
620     DCHECK_GT(stats_summaries_list.size(), 0);
621     const int32_t feature_dims = stats_summaries_list[0].dim_size(1);
622     // The last bucket is for default/missing value.
623     const int32_t num_buckets = stats_summaries_list[0].dim_size(2) - 1;
624     const int32_t logits_dim = logits_dim_;
625     const int32_t hessian_dim =
626         stats_summaries_list[0].dim_size(3) - logits_dim;
627     DCHECK_GT(hessian_dim, 0);
628     DCHECK_LE(hessian_dim, logits_dim * logits_dim);
629 
630     // Vector of stats_summaries; each element is stats for feature of shape
631     // [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim].
632     std::vector<TTypes<float, 4>::ConstTensor> stats_summaries;
633     DCHECK_EQ(stats_summaries_list.size(), num_features_);
634     stats_summaries.reserve(num_features_);
635     for (const auto& tensor : stats_summaries_list) {
636       stats_summaries.emplace_back(tensor.tensor<float, 4>());
637     }
638 
639     // Split types.
640     const Tensor* split_types_t;
641     OP_REQUIRES_OK(context, context->input("split_types", &split_types_t));
642     const auto split_types = split_types_t->vec<tstring>();
643     DCHECK_EQ(split_types.size(), num_features_);
644     // Validate.
645     for (int i = 0; i < num_features_; ++i) {
646       if (!(split_types(i) == kInequalitySplit ||
647             split_types(i) == kEqualitySplit)) {
648         OP_REQUIRES_OK(
649             context,
650             errors::Aborted(
651                 "Operation received an exception: Incorrect split type"));
652       }
653     }
654     // Feature ids.
655     const Tensor* candidate_feature_ids_t;
656     OP_REQUIRES_OK(context, context->input("candidate_feature_ids",
657                                            &candidate_feature_ids_t));
658     const auto candidate_feature_ids = candidate_feature_ids_t->vec<int32>();
659     DCHECK_EQ(candidate_feature_ids.size(), num_features_);
660 
661     // L1, L2, tree_complexity, min_node_weight.
662     const Tensor* l1_t;
663     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
664     const auto l1 = l1_t->scalar<float>()();
665     DCHECK_GE(l1, 0);
666     if (logits_dim_ > 1) {
667       // Multi-class L1 regularization not supported yet.
668       DCHECK_EQ(l1, 0);
669     }
670     const Tensor* l2_t;
671     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
672     const auto l2 = l2_t->scalar<float>()();
673     DCHECK_GE(l2, 0);
674     const Tensor* tree_complexity_t;
675     OP_REQUIRES_OK(context,
676                    context->input("tree_complexity", &tree_complexity_t));
677     const auto tree_complexity = tree_complexity_t->scalar<float>()();
678     const Tensor* min_node_weight_t;
679     OP_REQUIRES_OK(context,
680                    context->input("min_node_weight", &min_node_weight_t));
681     const auto min_node_weight = min_node_weight_t->scalar<float>()();
682 
683     std::vector<int32> output_node_ids;
684     std::vector<float> output_gains;
685     std::vector<int32> output_feature_ids;
686     std::vector<int32> output_feature_dimensions;
687     std::vector<int32> output_thresholds;
688     std::vector<Eigen::VectorXf> output_left_node_contribs;
689     std::vector<Eigen::VectorXf> output_right_node_contribs;
690     std::vector<string> output_split_types;
691 
692     // TODO(tanzheny) parallelize the computation.
693     // Iterate each node and find the best gain per node.
694     float parent_gain;
695     for (int32_t node_id = node_id_first; node_id < node_id_last; ++node_id) {
696       float best_gain = std::numeric_limits<float>::lowest();
697       int32_t best_bucket;
698       int32_t best_f_id;
699       int32_t best_f_dim;
700       string best_split_type;
701       Eigen::VectorXf best_contrib_for_left(logits_dim);
702       Eigen::VectorXf best_contrib_for_right(logits_dim);
703 
704       // Sum of gradient and hessian. Compute parent gain using first feature.
705       ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0),
706                                num_buckets + 1,  // Including default bucket.
707                                logits_dim + hessian_dim);
708       const Eigen::VectorXf total_grad =
709           stats_mat.leftCols(logits_dim).colwise().sum();
710       const Eigen::VectorXf total_hess =
711           stats_mat.rightCols(hessian_dim).colwise().sum();
712       if (total_hess.norm() < min_node_weight) {
713         continue;
714       }
715       Eigen::VectorXf unused(logits_dim);
716       CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
717                                &parent_gain);
718       for (int f_idx = 0; f_idx < num_features_; ++f_idx) {
719         const string split_type = split_types(f_idx);
720         TTypes<float, 4>::ConstTensor stats_summary = stats_summaries[f_idx];
721         float f_best_gain = std::numeric_limits<float>::lowest();
722         int32_t f_best_bucket;
723         int32_t f_best_f_dim;
724         string f_best_split_type;
725         Eigen::VectorXf f_best_contrib_for_left(logits_dim);
726         Eigen::VectorXf f_best_contrib_for_right(logits_dim);
727 
728         if (split_type == kInequalitySplit) {
729           CalculateBestInequalitySplit(
730               stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
731               num_buckets, min_node_weight, l1, l2, &f_best_gain,
732               &f_best_bucket, &f_best_f_dim, &f_best_split_type,
733               &f_best_contrib_for_left, &f_best_contrib_for_right);
734         } else {
735           CalculateBestEqualitySplit(
736               stats_summary, total_grad, total_hess, node_id, feature_dims,
737               logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain,
738               &f_best_bucket, &f_best_f_dim, &f_best_split_type,
739               &f_best_contrib_for_left, &f_best_contrib_for_right);
740         }
741         if (f_best_gain > best_gain) {
742           best_gain = f_best_gain;
743           best_f_id = candidate_feature_ids(f_idx);
744           best_f_dim = f_best_f_dim;
745           best_split_type = f_best_split_type;
746           best_bucket = f_best_bucket;
747           best_contrib_for_left = f_best_contrib_for_left;
748           best_contrib_for_right = f_best_contrib_for_right;
749         }
750       }  // For feature id.
751       if (best_gain == std::numeric_limits<float>::lowest()) {
752         // Do not add the node if no split is found.
753         continue;
754       }
755       output_node_ids.push_back(node_id);
756       // Remove the parent gain for the parent node.
757       output_gains.push_back(best_gain - parent_gain);
758       output_feature_ids.push_back(best_f_id);
759       output_feature_dimensions.push_back(best_f_dim);
760       // Default direction is fixed for dense splits.
761       // TODO(tanzheny) account for default values.
762       output_split_types.push_back(best_split_type);
763       output_thresholds.push_back(best_bucket);
764       output_left_node_contribs.push_back(best_contrib_for_left);
765       output_right_node_contribs.push_back(best_contrib_for_right);
766     }  // for node id.
767     const int num_nodes = output_node_ids.size();
768     // output_node_ids
769     Tensor* output_node_ids_t = nullptr;
770     OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
771                                                      &output_node_ids_t));
772     auto output_node_ids_vec = output_node_ids_t->vec<int32>();
773 
774     // output_gains
775     Tensor* output_gains_t;
776     OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
777                                                      &output_gains_t));
778     auto output_gains_vec = output_gains_t->vec<float>();
779 
780     // output_feature_ids
781     Tensor* output_features_ids_t;
782     OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes},
783                                                      &output_features_ids_t));
784     auto output_features_vec = output_features_ids_t->vec<int32>();
785 
786     // output_feature_dimensions
787     Tensor* output_feature_dimension_t;
788     OP_REQUIRES_OK(context,
789                    context->allocate_output("feature_dimensions", {num_nodes},
790                                             &output_feature_dimension_t));
791     auto output_feature_dimensions_vec =
792         output_feature_dimension_t->vec<int32>();
793 
794     // output_thresholds
795     Tensor* output_thresholds_t;
796     OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
797                                                      &output_thresholds_t));
798     auto output_thresholds_vec = output_thresholds_t->vec<int32>();
799 
800     // output_left_node_contribs
801     Tensor* output_left_node_contribs_t;
802     OP_REQUIRES_OK(context, context->allocate_output(
803                                 "left_node_contribs", {num_nodes, logits_dim},
804                                 &output_left_node_contribs_t));
805     auto output_left_node_contribs_matrix =
806         output_left_node_contribs_t->matrix<float>();
807 
808     // output_right_node_contribs
809     Tensor* output_right_node_contribs_t;
810     OP_REQUIRES_OK(context, context->allocate_output(
811                                 "right_node_contribs", {num_nodes, logits_dim},
812                                 &output_right_node_contribs_t));
813     auto output_right_node_contribs_matrix =
814         output_right_node_contribs_t->matrix<float>();
815 
816     // split type
817     Tensor* output_split_types_t;
818     OP_REQUIRES_OK(
819         context, context->allocate_output("split_with_default_directions",
820                                           {num_nodes}, &output_split_types_t));
821     auto output_split_types_vec = output_split_types_t->vec<tstring>();
822 
823     // Sets output tensors from vectors.
824     for (int i = 0; i < num_nodes; ++i) {
825       output_node_ids_vec(i) = output_node_ids[i];
826       output_features_vec(i) = output_feature_ids[i];
827       // Adjust the gains to penalize by tree complexity.
828       output_gains_vec(i) = output_gains[i] - tree_complexity;
829       output_feature_dimensions_vec(i) = output_feature_dimensions[i];
830       output_thresholds_vec(i) = output_thresholds[i];
831       for (int j = 0; j < logits_dim; ++j) {
832         output_left_node_contribs_matrix(i, j) =
833             output_left_node_contribs[i][j];
834         output_right_node_contribs_matrix(i, j) =
835             output_right_node_contribs[i][j];
836       }
837       output_split_types_vec(i) = output_split_types[i];
838     }
839   }
840 
841  private:
842   // TODO(crawles): Simplify inequality path just like equality b/138329196
843   // Currently this is not simplify-able due to numerical instability in math
844   // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g)
845   // It caused gain to be Inf when g is approaching 0 but not exactly 0 while
846   // there is no regularization.
847   // Calculate the best inequality split per node.
CalculateBestInequalitySplit(TTypes<float,4>::ConstTensor stats_summary,const int32_t node_id,const int32_t feature_dims,const int32_t logits_dim,const int32_t hessian_dim,const int32_t num_buckets,const float min_node_weight,const float l1,const float l2,float * best_gain,int32 * best_bucket,int32 * best_f_dim,string * best_split_type,Eigen::VectorXf * best_contrib_for_left,Eigen::VectorXf * best_contrib_for_right)848   void CalculateBestInequalitySplit(
849       TTypes<float, 4>::ConstTensor stats_summary, const int32_t node_id,
850       const int32_t feature_dims, const int32_t logits_dim,
851       const int32_t hessian_dim, const int32_t num_buckets,
852       const float min_node_weight, const float l1, const float l2,
853       float* best_gain, int32* best_bucket, int32* best_f_dim,
854       string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
855       Eigen::VectorXf* best_contrib_for_right) {
856     std::vector<Eigen::VectorXf> cum_grad;
857     std::vector<Eigen::VectorXf> cum_hess;
858     // get all cumulative gradients including default bucket.
859     cum_grad.reserve(num_buckets);
860     cum_hess.reserve(num_buckets);
861 
862     for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
863       ConstVectorMap default_stats_vec(
864           &stats_summary(node_id, f_dim, num_buckets, 0),
865           logits_dim + hessian_dim);
866       Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
867       Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
868       cum_grad.clear();
869       cum_hess.clear();
870       Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
871       Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
872       // sum all the gradients including default bucket.
873       for (int bucket = 0; bucket <= num_buckets; ++bucket) {
874         for (int i = 0; i < logits_dim; ++i) {
875           total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
876         }
877         for (int i = 0; i < hessian_dim; ++i) {
878           // Full hessian.
879           total_hess[i] +=
880               stats_summary(node_id, f_dim, bucket, logits_dim + i);
881         }
882         if (bucket < num_buckets) {
883           cum_grad.push_back(total_grad);
884           cum_hess.push_back(total_hess);
885         }
886       }
887       const string kInequalityDefaultLeft =
888           boosted_trees::SplitTypeWithDefault_Name(
889               boosted_trees::INEQUALITY_DEFAULT_LEFT);
890       const string kInequalityDefaultRight =
891           boosted_trees::SplitTypeWithDefault_Name(
892               boosted_trees::INEQUALITY_DEFAULT_RIGHT);
893 
894       // Iterate from left to right, excluding default bucket.
895       for (int bucket = 0; bucket < num_buckets; ++bucket) {
896         // default value goes to left node.
897         const Eigen::VectorXf total_left_grad =
898             cum_grad[bucket] + missing_bucket_grad;
899         const Eigen::VectorXf total_left_hess =
900             cum_hess[bucket] + missing_bucket_hess;
901         MaybeUpdateBestSplit(
902             total_left_grad, total_grad - total_left_grad, total_left_hess,
903             total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
904             kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
905             best_split_type, best_contrib_for_left, best_contrib_for_right);
906         // default value goes to right node.
907         MaybeUpdateBestSplit(
908             cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
909             total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
910             kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
911             best_split_type, best_contrib_for_left, best_contrib_for_right);
912       }  // for bucket
913     }
914   }
915 
916   // Calculate the best equality split per node.
CalculateBestEqualitySplit(TTypes<float,4>::ConstTensor stats_summary,const Eigen::VectorXf & total_grad,const Eigen::VectorXf & total_hess,const int32_t node_id,const int32_t feature_dims,const int32_t logits_dim,const int32_t hessian_dim,const int32_t num_buckets,const float l1,const float l2,float * best_gain,int32 * best_bucket,int32 * best_f_dim,string * best_split_type,Eigen::VectorXf * best_contrib_for_left,Eigen::VectorXf * best_contrib_for_right)917   void CalculateBestEqualitySplit(
918       TTypes<float, 4>::ConstTensor stats_summary,
919       const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
920       const int32_t node_id, const int32_t feature_dims,
921       const int32_t logits_dim, const int32_t hessian_dim,
922       const int32_t num_buckets, const float l1, const float l2,
923       float* best_gain, int32* best_bucket, int32* best_f_dim,
924       string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
925       Eigen::VectorXf* best_contrib_for_right) {
926     const string kEqualityDefaultRight =
927         boosted_trees::SplitTypeWithDefault_Name(
928             boosted_trees::EQUALITY_DEFAULT_RIGHT);
929     for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
930       for (int bucket = 0; bucket < num_buckets; ++bucket) {
931         ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
932                                  logits_dim + hessian_dim);
933         Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
934         Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
935         MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
936                              total_hess - curr_hess, logits_dim, bucket, f_dim,
937                              l1, l2, kEqualityDefaultRight, best_gain,
938                              best_bucket, best_f_dim, best_split_type,
939                              best_contrib_for_left, best_contrib_for_right);
940       }
941     }
942   }
943 
MaybeUpdateBestSplit(const Eigen::VectorXf & grad_for_left,const Eigen::VectorXf & grad_for_right,const Eigen::VectorXf & hess_for_left,const Eigen::VectorXf & hess_for_right,const int32_t logits_dim,const int32_t bucket,const int32_t f_dim,const float l1,const float l2,const string split_type,float * best_gain,int32 * best_bucket,int32 * best_f_dim,string * best_split_type,Eigen::VectorXf * best_contrib_for_left,Eigen::VectorXf * best_contrib_for_right)944   void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
945                             const Eigen::VectorXf& grad_for_right,
946                             const Eigen::VectorXf& hess_for_left,
947                             const Eigen::VectorXf& hess_for_right,
948                             const int32_t logits_dim, const int32_t bucket,
949                             const int32_t f_dim, const float l1, const float l2,
950                             const string split_type, float* best_gain,
951                             int32* best_bucket, int32* best_f_dim,
952                             string* best_split_type,
953                             Eigen::VectorXf* best_contrib_for_left,
954                             Eigen::VectorXf* best_contrib_for_right) {
955     // Left child.
956     Eigen::VectorXf contrib_for_left(logits_dim);
957     float gain_for_left;
958     CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
959                              &contrib_for_left, &gain_for_left);
960     Eigen::VectorXf contrib_for_right(logits_dim);
961     float gain_for_right;
962     CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
963                              &contrib_for_right, &gain_for_right);
964     if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
965       *best_gain = gain_for_left + gain_for_right;
966       *best_bucket = bucket;
967       *best_f_dim = f_dim;
968       *best_contrib_for_left = contrib_for_left;
969       *best_contrib_for_right = contrib_for_right;
970       *best_split_type = split_type;
971     }
972   }
973   int num_features_;
974   int logits_dim_;
975 };
976 
// v2 op that supports multi-class. Registers
// BoostedTreesCalculateBestFeatureSplitV2 as the CPU kernel for the op named
// "BoostedTreesCalculateBestFeatureSplitV2".
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU),
    BoostedTreesCalculateBestFeatureSplitV2);
981 
982 // Map from bucket id to vector of statistics.
983 typedef std::map<int32, std::vector<float>> BucketMap;
984 typedef BucketMap::iterator BucketMapIterator;
985 // Map from feature dimension to BucketMap.
986 typedef std::map<int32, BucketMap> FeatureMap;
987 typedef FeatureMap::iterator FeatureMapIterator;
988 
989 class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
990  public:
BoostedTreesSparseCalculateBestFeatureSplitOp(OpKernelConstruction * const context)991   explicit BoostedTreesSparseCalculateBestFeatureSplitOp(
992       OpKernelConstruction* const context)
993       : OpKernel(context) {
994     // TODO(crawles): Using logits_dim_ for multi-class split.
995     OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
996     // TODO(tanzheny): Using this for equality split.
997     OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
998   }
999 
Compute(OpKernelContext * const context)1000   void Compute(OpKernelContext* const context) override {
1001     // node_id_range
1002     const Tensor* node_id_range_t;
1003     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
1004     const auto node_id_range = node_id_range_t->vec<int32>();
1005     const int32_t node_id_first = node_id_range(0);  // inclusive
1006     const int32_t node_id_last = node_id_range(1);   // exclusive
1007 
1008     const Tensor* stats_summary_indices_t;
1009     OP_REQUIRES_OK(context, context->input("stats_summary_indices",
1010                                            &stats_summary_indices_t));
1011     const auto stats_summary_indices = stats_summary_indices_t->matrix<int32>();
1012     const int32_t num_sparse_entries = stats_summary_indices_t->dim_size(0);
1013 
1014     const Tensor* stats_summary_values_t;
1015     OP_REQUIRES_OK(context, context->input("stats_summary_values",
1016                                            &stats_summary_values_t));
1017     const auto stats_summary_values = stats_summary_values_t->vec<float>();
1018 
1019     const Tensor* stats_summary_shape_t;
1020     OP_REQUIRES_OK(
1021         context, context->input("stats_summary_shape", &stats_summary_shape_t));
1022     const auto stats_summary_shape = stats_summary_shape_t->vec<int32>();
1023     const int32_t num_buckets = stats_summary_shape(2) - 1;
1024     const int32_t stats_dims = stats_summary_shape(3);
1025 
1026     const Tensor* l1_t;
1027     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
1028     const auto l1 = l1_t->scalar<float>()();
1029 
1030     const Tensor* l2_t;
1031     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
1032     const auto l2 = l2_t->scalar<float>()();
1033 
1034     const Tensor* tree_complexity_t;
1035     OP_REQUIRES_OK(context,
1036                    context->input("tree_complexity", &tree_complexity_t));
1037     const auto tree_complexity = tree_complexity_t->scalar<float>()();
1038 
1039     const Tensor* min_node_weight_t;
1040     OP_REQUIRES_OK(context,
1041                    context->input("min_node_weight", &min_node_weight_t));
1042     const auto min_node_weight = min_node_weight_t->scalar<float>()();
1043 
1044     std::vector<int32> output_node_ids;
1045     std::vector<float> output_gains;
1046     std::vector<int32> output_feature_dimensions;
1047     std::vector<int32> output_thresholds;
1048     std::vector<float> output_left_node_contribs;
1049     std::vector<float> output_right_node_contribs;
1050     std::vector<string> output_split_types;
1051 
1052     FeatureMap f_map;
1053 
1054     int32_t previous_node_id = -1;
1055     for (int idx = 0; idx < num_sparse_entries; ++idx) {
1056       int32_t node_id = stats_summary_indices(idx, 0);
1057       if (node_id != previous_node_id) {
1058         process_node(f_map, &output_node_ids, &output_gains,
1059                      &output_feature_dimensions, &output_thresholds,
1060                      &output_left_node_contribs, &output_right_node_contribs,
1061                      &output_split_types, previous_node_id, min_node_weight, l1,
1062                      l2, num_buckets);
1063         f_map.clear();
1064       }
1065       previous_node_id = node_id;
1066       DCHECK_LE(node_id_first, node_id);
1067       DCHECK_LT(node_id, node_id_last);
1068       const int32_t feature_dim = stats_summary_indices(idx, 1);
1069       const int32_t bucket_id = stats_summary_indices(idx, 2);
1070       const int32_t stat_dim = stats_summary_indices(idx, 3);
1071       OP_REQUIRES(context, stat_dim < stats_dims,
1072                   errors::InvalidArgument(
1073                       "Stat dim, the sum of logits dim and hessian dim in "
1074                       "stats_summary_indices, cannot be greater than stats "
1075                       "dims, the last value in stats_summary_shape, which was ",
1076                       stats_dims, ". At index (", idx,
1077                       ", 4), stats_summary_indices contains value ", stat_dim));
1078       std::pair<FeatureMapIterator, bool> const& f_insert_result = f_map.insert(
1079           FeatureMapIterator::value_type(feature_dim, BucketMap()));
1080       auto& b_map = f_insert_result.first->second;
1081       std::pair<BucketMapIterator, bool> const& b_insert_result =
1082           b_map.insert(BucketMapIterator::value_type(
1083               bucket_id, std::vector<float>(stats_dims)));
1084       auto& stats = b_insert_result.first->second;
1085       stats[stat_dim] = stats_summary_values(idx);
1086     }  // for node_id
1087     // process the last node id
1088     process_node(f_map, &output_node_ids, &output_gains,
1089                  &output_feature_dimensions, &output_thresholds,
1090                  &output_left_node_contribs, &output_right_node_contribs,
1091                  &output_split_types, previous_node_id, min_node_weight, l1, l2,
1092                  num_buckets);
1093 
1094     const int num_nodes = output_node_ids.size();
1095     // output_node_ids
1096     Tensor* output_node_ids_t = nullptr;
1097     OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
1098                                                      &output_node_ids_t));
1099     auto output_node_ids_vec = output_node_ids_t->vec<int32>();
1100 
1101     // output_gains
1102     Tensor* output_gains_t;
1103     OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
1104                                                      &output_gains_t));
1105     auto output_gains_vec = output_gains_t->vec<float>();
1106 
1107     // output_feature_dimensions
1108     Tensor* output_feature_dimension_t;
1109     OP_REQUIRES_OK(context,
1110                    context->allocate_output("feature_dimensions", {num_nodes},
1111                                             &output_feature_dimension_t));
1112     auto output_feature_dimensions_vec =
1113         output_feature_dimension_t->vec<int32>();
1114 
1115     // output_thresholds
1116     Tensor* output_thresholds_t;
1117     OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
1118                                                      &output_thresholds_t));
1119     auto output_thresholds_vec = output_thresholds_t->vec<int32>();
1120 
1121     // output_left_node_contribs
1122     Tensor* output_left_node_contribs_t;
1123     OP_REQUIRES_OK(
1124         context, context->allocate_output("left_node_contribs", {num_nodes, 1},
1125                                           &output_left_node_contribs_t));
1126     auto output_left_node_contribs_matrix =
1127         output_left_node_contribs_t->matrix<float>();
1128 
1129     // output_right_node_contribs
1130     Tensor* output_right_node_contribs_t;
1131     OP_REQUIRES_OK(
1132         context, context->allocate_output("right_node_contribs", {num_nodes, 1},
1133                                           &output_right_node_contribs_t));
1134     auto output_right_node_contribs_matrix =
1135         output_right_node_contribs_t->matrix<float>();
1136 
1137     // split type
1138     Tensor* output_split_types_t;
1139     OP_REQUIRES_OK(
1140         context, context->allocate_output("split_with_default_directions",
1141                                           {num_nodes}, &output_split_types_t));
1142     auto output_split_types_vec = output_split_types_t->vec<tstring>();
1143 
1144     // Sets output tensors from vectors.
1145     for (int i = 0; i < num_nodes; ++i) {
1146       output_node_ids_vec(i) = output_node_ids[i];
1147       // Adjust the gains to penalize by tree complexity.
1148       output_gains_vec(i) = output_gains[i] - tree_complexity;
1149       output_feature_dimensions_vec(i) = output_feature_dimensions[i];
1150       output_thresholds_vec(i) = output_thresholds[i];
1151       // TODO(crawles): change this for multi-class.
1152       output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
1153       output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
1154       output_split_types_vec(i) = output_split_types[i];
1155     }
1156   }
1157 
1158  protected:
process_node(const FeatureMap & f_map,std::vector<int32> * output_node_ids,std::vector<float> * output_gains,std::vector<int32> * output_feature_dimensions,std::vector<int32> * output_thresholds,std::vector<float> * output_left_node_contribs,std::vector<float> * output_right_node_contribs,std::vector<string> * output_split_types,const int32_t node_id,const float min_node_weight,const float l1,const float l2,const int32_t num_buckets)1159   void process_node(const FeatureMap& f_map,
1160                     std::vector<int32>* output_node_ids,
1161                     std::vector<float>* output_gains,
1162                     std::vector<int32>* output_feature_dimensions,
1163                     std::vector<int32>* output_thresholds,
1164                     std::vector<float>* output_left_node_contribs,
1165                     std::vector<float>* output_right_node_contribs,
1166                     std::vector<string>* output_split_types,
1167                     const int32_t node_id, const float min_node_weight,
1168                     const float l1, const float l2, const int32_t num_buckets) {
1169     float parent_gain;
1170     Eigen::VectorXf unused(logits_dim_);
1171     Eigen::MatrixXf identity;
1172     identity.setIdentity(1, 1);
1173 
1174     // start processing for previous node id.
1175     float best_gain = std::numeric_limits<float>::lowest();
1176     float best_bucket = 0;
1177     float best_f_dim = 0;
1178     string best_split_type = boosted_trees::SplitTypeWithDefault_Name(
1179         boosted_trees::INEQUALITY_DEFAULT_LEFT);
1180     float best_contrib_for_left = 0.0;
1181     float best_contrib_for_right = 0.0;
1182     // the sum of gradients including default bucket.
1183     float total_grad = 0;
1184     // the sum of hessians including default bucket.
1185     float total_hess = 0;
1186 
1187     for (auto f_iter = f_map.begin(); f_iter != f_map.end(); ++f_iter) {
1188       const int32_t feature_dim = f_iter->first;
1189       const auto buckets_to_stats_map = f_iter->second;
1190 
1191       // The very last bucket contains stats for missing values.
1192       // TODO(crawles): use vector for multi-class.
1193       const float default_grad =
1194           (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
1195                ? 0
1196                : buckets_to_stats_map.at(num_buckets)[0]);
1197       const float default_hess =
1198           (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
1199                ? 0
1200                : buckets_to_stats_map.at(num_buckets)[1]);
1201 
1202       if (f_iter == f_map.begin()) {
1203         // first get the sum of grads, including default bucket.
1204         for (auto b_iter = buckets_to_stats_map.begin();
1205              b_iter != buckets_to_stats_map.end(); ++b_iter) {
1206           total_grad += b_iter->second[0];
1207           total_hess += b_iter->second[1];
1208         }
1209         if (total_hess < min_node_weight) {
1210           // Do not split the node because not enough avg hessian.
1211           break;
1212         }
1213         CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
1214                                  l1, l2, &unused, &parent_gain);
1215       }
1216 
1217       float total_left_grad = 0;
1218       float total_left_hess = 0;
1219       for (auto b_iter = buckets_to_stats_map.begin();
1220            b_iter != buckets_to_stats_map.end(); ++b_iter) {
1221         const int32_t bucket_id = b_iter->first;
1222         // total_left_stats should exclude stats from default bucket.
1223         if (bucket_id == num_buckets) {
1224           break;
1225         }
1226         // TODO(crawles): vector for multi-class.
1227         total_left_grad += b_iter->second[0];
1228         total_left_hess += b_iter->second[1];
1229         // From left to right, default right.
1230         // Left child.
1231         Eigen::VectorXf contrib_for_left(1);
1232         float gain_for_left;
1233         CalculateWeightsAndGains(total_left_grad * identity,
1234                                  total_left_hess * identity, l1, l2,
1235                                  &contrib_for_left, &gain_for_left);
1236         // Right child.
1237         Eigen::VectorXf contrib_for_right(1);
1238         float gain_for_right;
1239         CalculateWeightsAndGains((total_grad - total_left_grad) * identity,
1240                                  (total_hess - total_left_hess) * identity, l1,
1241                                  l2, &contrib_for_right, &gain_for_right);
1242         if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
1243           best_gain = gain_for_left + gain_for_right;
1244           best_bucket = bucket_id;
1245           best_f_dim = feature_dim;
1246           best_split_type = boosted_trees::SplitTypeWithDefault_Name(
1247               boosted_trees::INEQUALITY_DEFAULT_RIGHT);
1248           best_contrib_for_left = contrib_for_left[0];
1249           best_contrib_for_right = contrib_for_right[0];
1250         }
1251 
1252         // From right to left, default left.
1253         CalculateWeightsAndGains((total_left_grad + default_grad) * identity,
1254                                  (total_left_hess + default_hess) * identity,
1255                                  l1, l2, &contrib_for_left, &gain_for_left);
1256         CalculateWeightsAndGains(
1257             (total_grad - default_grad - total_left_grad) * identity,
1258             (total_hess - default_hess - total_left_hess) * identity, l1, l2,
1259             &contrib_for_right, &gain_for_right);
1260         if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
1261           best_gain = gain_for_left + gain_for_right;
1262           best_bucket = bucket_id;
1263           best_f_dim = feature_dim;
1264           best_split_type = boosted_trees::SplitTypeWithDefault_Name(
1265               boosted_trees::INEQUALITY_DEFAULT_LEFT);
1266           best_contrib_for_left = contrib_for_left[0];
1267           best_contrib_for_right = contrib_for_right[0];
1268         }
1269       }  // for bucket_id
1270     }    // for feature_dim
1271     if (best_gain != std::numeric_limits<float>::lowest()) {
1272       output_node_ids->push_back(node_id);
1273       // Remove the parent gain.
1274       output_gains->push_back(best_gain - parent_gain);
1275       output_feature_dimensions->push_back(best_f_dim);
1276       output_split_types->push_back(best_split_type);
1277       output_thresholds->push_back(best_bucket);
1278       output_left_node_contribs->push_back(best_contrib_for_left);
1279       output_right_node_contribs->push_back(best_contrib_for_right);
1280     }
1281   }
1282 
1283  private:
1284   int logits_dim_;
1285   string split_type_;
1286 };
1287 
1288 REGISTER_KERNEL_BUILDER(
1289     Name("BoostedTreesSparseCalculateBestFeatureSplit").Device(DEVICE_CPU),
1290     BoostedTreesSparseCalculateBestFeatureSplitOp);
1291 
1292 class BoostedTreesMakeStatsSummaryOp : public OpKernel {
1293  public:
BoostedTreesMakeStatsSummaryOp(OpKernelConstruction * const context)1294   explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
1295       : OpKernel(context) {
1296     OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
1297     OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
1298     OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
1299   }
1300 
Compute(OpKernelContext * const context)1301   void Compute(OpKernelContext* const context) override {
1302     // node_ids
1303     const Tensor* node_ids_t;
1304     OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
1305     const auto node_ids = node_ids_t->vec<int32>();
1306     // gradients
1307     const Tensor* gradients_t;
1308     OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
1309     const auto gradients = gradients_t->matrix<float>();
1310     // hessians
1311     const Tensor* hessians_t;
1312     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
1313     const auto hessians = hessians_t->matrix<float>();
1314     // bucketized_features
1315     OpInputList bucketized_features_list;
1316     OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
1317                                                 &bucketized_features_list));
1318     // Infer batch size.
1319     const int64_t batch_size = node_ids_t->dim_size(0);
1320 
1321     // Allocate temporary stats tensor (Rank 4).
1322     Tensor temp_stats_double_t;
1323     OP_REQUIRES_OK(context, context->allocate_temp(
1324                                 DT_DOUBLE,
1325                                 {num_features_, max_splits_, num_buckets_, 2},
1326                                 &temp_stats_double_t));
1327     auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
1328     temp_stats_double.setZero();
1329 
1330     // Partition by node, and then bucketize.
1331     for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
1332       const auto& features = bucketized_features_list[feature_idx].vec<int32>();
1333       for (int i = 0; i < batch_size; ++i) {
1334         const int32_t node = node_ids(i);
1335         const int32_t bucket = features(i);
1336         temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
1337         temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
1338       }
1339     }
1340 
1341     // Copy temp tensor over to output tensor.
1342     Tensor* output_stats_summary_t = nullptr;
1343     OP_REQUIRES_OK(context, context->allocate_output(
1344                                 "stats_summary", temp_stats_double_t.shape(),
1345                                 &output_stats_summary_t));
1346     output_stats_summary_t->tensor<float, 4>() =
1347         temp_stats_double.template cast<float>();
1348   }
1349 
1350  private:
1351   int max_splits_;
1352   int num_buckets_;
1353   int num_features_;
1354 };
1355 
1356 REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
1357                         BoostedTreesMakeStatsSummaryOp);
1358 
1359 // TODO(tanzheny): Add an option of default value into the API interface.
1360 class BoostedTreesAggregateStatsOp : public OpKernel {
1361  public:
BoostedTreesAggregateStatsOp(OpKernelConstruction * const context)1362   explicit BoostedTreesAggregateStatsOp(OpKernelConstruction* const context)
1363       : OpKernel(context) {
1364     OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
1365     OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
1366   }
1367 
Compute(OpKernelContext * const context)1368   void Compute(OpKernelContext* const context) override {
1369     // node_ids.
1370     const Tensor* node_ids_t;
1371     OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
1372     const auto node_ids = node_ids_t->vec<int32>();
1373 
1374     // gradients.
1375     const Tensor* gradients_t;
1376     OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
1377     const auto gradients = gradients_t->matrix<float>();
1378 
1379     // hessians.
1380     const Tensor* hessians_t;
1381     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
1382     const auto hessians = hessians_t->matrix<float>();
1383 
1384     // feature.
1385     const Tensor* feature_t;
1386     OP_REQUIRES_OK(context, context->input("feature", &feature_t));
1387     const auto feature = feature_t->matrix<int32>();
1388 
1389     // Infer batch size, feature dimension and stats dimension.
1390     const int64_t batch_size = node_ids_t->dim_size(0);
1391     const int64_t logits_dims = gradients_t->dim_size(1);
1392     const int64_t hessians_dims = hessians_t->dim_size(1);
1393     const int64_t stats_dims = logits_dims + hessians_dims;
1394     const int64_t feature_dims = feature_t->dim_size(1);
1395 
1396     // Allocate temporary stats tensor (Rank 4), upcasting to double.
1397     // A default bucket is added to the end for missing/default values.
1398     Tensor temp_stats_double_t;
1399     OP_REQUIRES_OK(
1400         context, context->allocate_temp(
1401                      DT_DOUBLE,
1402                      {max_splits_, feature_dims, num_buckets_ + 1, stats_dims},
1403                      &temp_stats_double_t));
1404     auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
1405     temp_stats_double.setZero();
1406 
1407     for (int i = 0; i < batch_size; ++i) {
1408       const int32_t node = node_ids(i);
1409       for (int feature_dim = 0; feature_dim < feature_dims; ++feature_dim) {
1410         const int32_t feature_value = feature(i, feature_dim);
1411         const int32_t bucket =
1412             (feature_value == -1) ? num_buckets_ : feature_value;
1413         for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
1414           temp_stats_double(node, feature_dim, bucket, stat_dim) +=
1415               gradients(i, stat_dim);
1416         }
1417         for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
1418           temp_stats_double(node, feature_dim, bucket, stat_dim) +=
1419               hessians(i, stat_dim - logits_dims);
1420         }
1421       }
1422     }
1423 
1424     // Copy temp tensor over to output tensor, downcasting to float.
1425     Tensor* output_stats_summary_t = nullptr;
1426     OP_REQUIRES_OK(context, context->allocate_output(
1427                                 "stats_summary", temp_stats_double_t.shape(),
1428                                 &output_stats_summary_t));
1429     output_stats_summary_t->tensor<float, 4>() =
1430         temp_stats_double.template cast<float>();
1431   }
1432 
1433  private:
1434   int max_splits_;
1435   int num_buckets_;
1436 };
1437 
1438 REGISTER_KERNEL_BUILDER(Name("BoostedTreesAggregateStats").Device(DEVICE_CPU),
1439                         BoostedTreesAggregateStatsOp);
1440 
// Identifies one statistics slot by (node id, feature dimension, bucket id).
struct StatsPartitionKey {
  StatsPartitionKey(const int32_t node_id, const int32_t feature_dim,
                    const int32_t bucket_id)
      : node_id(node_id), feature_dim(feature_dim), bucket_id(bucket_id) {}

  // Two keys are equal only when all three components match.
  bool operator==(const StatsPartitionKey& other) const {
    if (node_id != other.node_id) {
      return false;
    }
    if (feature_dim != other.feature_dim) {
      return false;
    }
    return bucket_id == other.bucket_id;
  }

  // Ordering for StatsPartitionKey: lexicographic on
  // (node_id, feature_dim, bucket_id).
  struct Less {
    bool operator()(const StatsPartitionKey& a,
                    const StatsPartitionKey& b) const {
      // The first component that differs decides the order.
      if (a.node_id != b.node_id) {
        return a.node_id < b.node_id;
      }
      if (a.feature_dim != b.feature_dim) {
        return a.feature_dim < b.feature_dim;
      }
      return a.bucket_id < b.bucket_id;
    }
  };

  // Tree node id.
  int32_t node_id;
  // Dimension within feature column.
  int32_t feature_dim;
  // Bucketized feature value.
  int32_t bucket_id;
};
1477 
1478 typedef std::map<StatsPartitionKey, std::vector<float>, StatsPartitionKey::Less>
1479     StatsPartitionMap;
1480 typedef StatsPartitionMap::iterator StatsPartitionIterator;
1481 
// Identifies a position by (instance, feature dimension).
struct InstanceFeatureDimKey {
  // Default key uses -1 for both components.
  InstanceFeatureDimKey() : instance(-1), feature_dim(-1) {}

  InstanceFeatureDimKey(const int32_t instance, const int32_t feature_dim)
      : instance(instance), feature_dim(feature_dim) {}

  // Two keys are equal only when both components match.
  bool operator==(const InstanceFeatureDimKey& other) const {
    if (instance != other.instance) {
      return false;
    }
    return feature_dim == other.feature_dim;
  }

  // Ordering for InstanceFeatureDimKey: lexicographic on
  // (instance, feature_dim).
  struct Less {
    bool operator()(const InstanceFeatureDimKey& a,
                    const InstanceFeatureDimKey& b) const {
      if (a.instance != b.instance) {
        return a.instance < b.instance;
      }
      return a.feature_dim < b.feature_dim;
    }
  };

  // Instance id within a batch.
  int32_t instance;
  // Dimension within feature column.
  int32_t feature_dim;
};
1512 
1513 // Add statistics to StatsPartitionMap for (instance, feature dim, bucket id).
AddInstanceStatsToMap(const int32_t instance,const int32_t feature_dim,const int32_t bucket_id,const int32_t logits_dims,const int32_t stats_dims,StatsPartitionMap * stats_map,const TTypes<float>::ConstMatrix & gradients,const TTypes<float>::ConstMatrix & hessians,const TTypes<int32>::ConstVec & node_ids)1514 static void AddInstanceStatsToMap(
1515     const int32_t instance, const int32_t feature_dim, const int32_t bucket_id,
1516     const int32_t logits_dims, const int32_t stats_dims,
1517     StatsPartitionMap* stats_map, const TTypes<float>::ConstMatrix& gradients,
1518     const TTypes<float>::ConstMatrix& hessians,
1519     const TTypes<int32>::ConstVec& node_ids) {
1520   const int32_t node_id = node_ids(instance);
1521   const auto key = StatsPartitionKey(node_id, feature_dim, bucket_id);
1522   std::pair<StatsPartitionIterator, bool> const& insert_result =
1523       stats_map->insert(StatsPartitionIterator::value_type(
1524           key, std::vector<float>(stats_dims, 0.0f)));
1525   auto& stats = insert_result.first->second;
1526   for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
1527     stats[stat_dim] += gradients(instance, stat_dim);
1528   }
1529   for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
1530     stats[stat_dim] += hessians(instance, stat_dim - logits_dims);
1531   }
1532 }
1533 
1534 // Add statistics to StatsPartitionMap for bucket_id ranging from
1535 // (start_instance, start_feature_dim) to (end_instance, end_feature_dim),
1536 // inclusive on start and end instances, exclusive on end feature dim.
AddRangeStats(const int start_instance,const int end_instance,const int start_feature_dim,const int end_feature_dim,StatsPartitionMap * stats_map,const TTypes<float>::ConstMatrix & gradients,const TTypes<float>::ConstMatrix & hessians,const TTypes<int32>::ConstVec & node_ids,const int32_t feature_dims,const int32_t bucket_id,const int32_t logits_dims,const int32_t stats_dims)1537 static void AddRangeStats(const int start_instance, const int end_instance,
1538                           const int start_feature_dim,
1539                           const int end_feature_dim,
1540                           StatsPartitionMap* stats_map,
1541                           const TTypes<float>::ConstMatrix& gradients,
1542                           const TTypes<float>::ConstMatrix& hessians,
1543                           const TTypes<int32>::ConstVec& node_ids,
1544                           const int32_t feature_dims, const int32_t bucket_id,
1545                           const int32_t logits_dims, const int32_t stats_dims) {
1546   DCHECK_LE(start_instance, end_instance);
1547   if (start_instance == end_instance) {
1548     DCHECK_LT(start_feature_dim, end_feature_dim);
1549   }
1550   for (int32_t instance = start_instance; instance <= end_instance;
1551        ++instance) {
1552     const int32_t start_f_dim =
1553         (instance == start_instance) ? start_feature_dim + 1 : 0;
1554     const int32_t end_f_dim =
1555         (instance == end_instance) ? end_feature_dim : feature_dims;
1556     for (int32_t f_dim = start_f_dim; f_dim < end_f_dim; ++f_dim) {
1557       AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
1558                             stats_map, gradients, hessians, node_ids);
1559     }
1560   }
1561 }
1562 
1563 class BoostedTreesSparseAggregateStatsOp : public OpKernel {
1564  public:
BoostedTreesSparseAggregateStatsOp(OpKernelConstruction * const context)1565   explicit BoostedTreesSparseAggregateStatsOp(
1566       OpKernelConstruction* const context)
1567       : OpKernel(context) {
1568     OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
1569     OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
1570   }
1571 
Compute(OpKernelContext * const context)1572   void Compute(OpKernelContext* const context) override {
1573     // node_ids.
1574     const Tensor* node_ids_t;
1575     OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
1576     const auto node_ids = node_ids_t->vec<int32>();
1577 
1578     // gradients.
1579     const Tensor* gradients_t;
1580     OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
1581     const auto gradients = gradients_t->matrix<float>();
1582 
1583     // hessians.
1584     const Tensor* hessians_t;
1585     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
1586     const auto hessians = hessians_t->matrix<float>();
1587 
1588     // feature indices.
1589     const Tensor* feature_indices_t;
1590     OP_REQUIRES_OK(context,
1591                    context->input("feature_indices", &feature_indices_t));
1592     const auto feature_indices = feature_indices_t->matrix<int32>();
1593 
1594     // feature values.
1595     const Tensor* feature_values_t;
1596     OP_REQUIRES_OK(context,
1597                    context->input("feature_values", &feature_values_t));
1598     const auto feature_values = feature_values_t->vec<int32>();
1599 
1600     // feature shape.
1601     const Tensor* feature_shape_t;
1602     OP_REQUIRES_OK(context, context->input("feature_shape", &feature_shape_t));
1603     OP_REQUIRES(context, TensorShapeUtils::IsVector(feature_shape_t->shape()),
1604                 errors::InvalidArgument(
1605                     "Input shapes should be a vector but received shapes ",
1606                     feature_shape_t->shape().DebugString()));
1607     const auto feature_shape = feature_shape_t->vec<int32>();
1608 
1609     const int64_t batch_size = gradients_t->dim_size(0);
1610     const int64_t logits_dims = gradients_t->dim_size(1);
1611     const int64_t hessians_dims = hessians_t->dim_size(1);
1612     const int64_t stats_dims = logits_dims + hessians_dims;
1613     const int64_t num_sparse_entries = feature_indices_t->dim_size(0);
1614     const int32_t feature_dims = feature_shape(1);
1615     DCHECK_LE(num_sparse_entries, batch_size * feature_dims);
1616 
1617     // Aggregate statistics info to map.
1618     StatsPartitionMap stats_map;
1619 
1620     int prev_instance = 0;
1621     int prev_f_dim = -1;
1622 
1623     for (int i = 0; i < num_sparse_entries; ++i) {
1624       // the instance number within a batch
1625       const int32_t instance = feature_indices(i, 0);
1626       DCHECK_LE(instance, batch_size);
1627       DCHECK_GE(instance, prev_instance);
1628       // the node id within a tree.
1629       const int32_t node_id = node_ids(instance);
1630       DCHECK_LE(node_id, max_splits_);
1631       // the feature dimension.
1632       const int32_t f_dim = feature_indices(i, 1);
1633       DCHECK_LE(f_dim, feature_dims);
1634       // the bucket id of the value.
1635       const int32_t bucket_id = feature_values(i);
1636       DCHECK_LE(bucket_id, num_buckets_);
1637 
1638       // Add statistics for the missing entries into default bucket.
1639       // The last bucket is default bucket.
1640       const int missing_entry_bucket = num_buckets_;
1641       AddRangeStats(prev_instance, instance, prev_f_dim, f_dim, &stats_map,
1642                     gradients, hessians, node_ids, feature_dims,
1643                     missing_entry_bucket, logits_dims, stats_dims);
1644       prev_instance = instance;
1645       prev_f_dim = f_dim;
1646       // Add statistics for the non-missing entry into
1647       // (cur_instance, cur_f_dim, bucket_id).
1648       AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
1649                             &stats_map, gradients, hessians, node_ids);
1650     }
1651     AddRangeStats(prev_instance, batch_size - 1, prev_f_dim, feature_dims,
1652                   &stats_map, gradients, hessians, node_ids, feature_dims,
1653                   num_buckets_, logits_dims, stats_dims);
1654 
1655     // Serialize statistics info map to tensor output.
1656     const int64_t num_slots = stats_map.size() * stats_dims;
1657     Tensor* summary_indices_t = nullptr;
1658     OP_REQUIRES_OK(context,
1659                    context->allocate_output("stats_summary_indices",
1660                                             TensorShape({num_slots, 4}),
1661                                             &summary_indices_t));
1662     auto summary_indices = summary_indices_t->matrix<int32>();
1663     Tensor* summary_values_t = nullptr;
1664     OP_REQUIRES_OK(context, context->allocate_output("stats_summary_values",
1665                                                      TensorShape({num_slots}),
1666                                                      &summary_values_t));
1667     auto summary_values = summary_values_t->vec<float>();
1668     int entry_index = 0;
1669     for (auto& iter : stats_map) {
1670       for (int stat_dim = 0; stat_dim < stats_dims; ++stat_dim) {
1671         summary_indices(entry_index, 0) = iter.first.node_id;
1672         summary_indices(entry_index, 1) = iter.first.feature_dim;
1673         summary_indices(entry_index, 2) = iter.first.bucket_id;
1674         summary_indices(entry_index, 3) = stat_dim;
1675         summary_values(entry_index) = iter.second[stat_dim];
1676         ++entry_index;
1677       }
1678     }
1679 
1680     Tensor* summary_shape_t = nullptr;
1681     OP_REQUIRES_OK(
1682         context, context->allocate_output("stats_summary_shape",
1683                                           TensorShape({4}), &summary_shape_t));
1684     auto summary_shape = summary_shape_t->vec<int32>();
1685     summary_shape(0) = max_splits_;
1686     summary_shape(1) = feature_dims;
1687     summary_shape(2) = num_buckets_ + 1;
1688     summary_shape(3) = stats_dims;
1689   }
1690 
1691  private:
1692   int max_splits_;
1693   int num_buckets_;
1694 };
1695 
1696 REGISTER_KERNEL_BUILDER(
1697     Name("BoostedTreesSparseAggregateStats").Device(DEVICE_CPU),
1698     BoostedTreesSparseAggregateStatsOp);
1699 
1700 }  // namespace tensorflow
1701