/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using Matrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;

using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
using VectorMap = Eigen::Map<Eigen::VectorXf>;

constexpr char kInequalitySplit[] = "inequality";
constexpr char kEqualitySplit[] = "equality";

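// An inequality split thresholds a bucketized feature (buckets <= the chosen
// bucket go to the left child); an equality split sends a single bucket left
// and all remaining buckets right.
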
// V1 op. Deprecated; use BoostedTreesCalculateBestFeatureSplitV2 instead.
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
 public:
  explicit BoostedTreesCalculateBestGainsPerFeatureOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
  }

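  // For each feature, evaluates every bucket as a candidate split for each
  // node in [node_id_first, node_id_last) and emits per-feature lists of
  // node ids, gains, thresholds, and left/right child contributions.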
  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32 node_id_first = node_id_range(0);  // inclusive
    const int32 node_id_last = node_id_range(1);   // exclusive
    // stats_summary_list
    OpInputList stats_summary_list;
    OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
                                                &stats_summary_list));
    const int64 num_buckets = stats_summary_list[0].dim_size(1);
    // Check for single logit: 1 gradient + 1 hessian value.
    DCHECK_EQ(stats_summary_list[0].dim_size(2), 2);
    std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
    stats_summary.reserve(stats_summary_list.size());
    for (const auto& tensor : stats_summary_list) {
      stats_summary.emplace_back(tensor.tensor<float, 3>());
    }
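    // stats_summary[f] has shape [max_splits, num_buckets, 2]:
    // (node, bucket, 0) holds the gradient sum and (node, bucket, 1) the
    // hessian sum accumulated for that bucket.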
    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();
    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();
    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    // Allocate output lists of tensors:
    OpOutputList output_node_ids_list;
    OP_REQUIRES_OK(
        context, context->output_list("node_ids_list", &output_node_ids_list));
    OpOutputList output_gains_list;
    OP_REQUIRES_OK(context,
                   context->output_list("gains_list", &output_gains_list));
    OpOutputList output_thresholds_list;
    OP_REQUIRES_OK(context, context->output_list("thresholds_list",
                                                 &output_thresholds_list));
    OpOutputList output_left_node_contribs_list;
    OP_REQUIRES_OK(context,
                   context->output_list("left_node_contribs_list",
                                        &output_left_node_contribs_list));
    OpOutputList output_right_node_contribs_list;
    OP_REQUIRES_OK(context,
                   context->output_list("right_node_contribs_list",
                                        &output_right_node_contribs_list));

    // Use identity later to convert float to Eigen::Matrix type for input to
    // CalculateWeightsAndGains. This op only supports single dimension logits.
    Eigen::MatrixXf identity;
    identity.setIdentity(1, 1);
    // Get the best split info per node for each feature.
    for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
      std::vector<float> cum_grad;
      std::vector<float> cum_hess;
      cum_grad.reserve(num_buckets);
      cum_hess.reserve(num_buckets);

      std::vector<int32> output_node_ids;
      std::vector<float> output_gains;
      std::vector<int32> output_thresholds;
      std::vector<float> output_left_node_contribs;
      std::vector<float> output_right_node_contribs;
      for (int node_id = node_id_first; node_id < node_id_last; ++node_id) {
        // Calculate gains.
        cum_grad.clear();
        cum_hess.clear();
        float total_grad = 0.0;
        float total_hess = 0.0;
        for (int bucket = 0; bucket < num_buckets; ++bucket) {
          // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
          total_grad += stats_summary[feature_idx](node_id, bucket, 0);
          total_hess += stats_summary[feature_idx](node_id, bucket, 1);
          cum_grad.push_back(total_grad);
          cum_hess.push_back(total_hess);
        }
        // Check whether the node has a large enough hessian sum.
        if (total_hess < min_node_weight) {
          // Do not split the node: its total hessian is below min_node_weight.
          continue;
        }
        float best_gain = std::numeric_limits<float>::lowest();
        int32 best_bucket = 0;
        float best_contrib_for_left = 0.0;
        float best_contrib_for_right = 0.0;
        // Parent gain.
        float parent_gain;
        Eigen::VectorXf unused(1);
        CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
                                 l1, l2, &unused, &parent_gain);

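        // A candidate split's quality is the sum of the left and right child
        // gains; parent_gain is subtracted when the result is recorded, so
        // the output gain measures the improvement over not splitting.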
        for (int bucket = 0; bucket < num_buckets; ++bucket) {
          const float cum_grad_bucket = cum_grad[bucket];
          const float cum_hess_bucket = cum_hess[bucket];
          // Left child.
          Eigen::VectorXf contrib_for_left(1);
          float gain_for_left;
          CalculateWeightsAndGains(cum_grad_bucket * identity,
                                   cum_hess_bucket * identity, l1, l2,
                                   &contrib_for_left, &gain_for_left);
          // Right child.
          Eigen::VectorXf contrib_for_right(1);
          float gain_for_right;
          CalculateWeightsAndGains((total_grad - cum_grad_bucket) * identity,
                                   (total_hess - cum_hess_bucket) * identity,
                                   l1, l2, &contrib_for_right, &gain_for_right);

          if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
            best_gain = gain_for_left + gain_for_right;
            best_bucket = bucket;
            best_contrib_for_left = contrib_for_left[0];
            best_contrib_for_right = contrib_for_right[0];
          }
        }  // for bucket
        output_node_ids.push_back(node_id);
        // Subtract the parent gain to report the improvement from splitting.
        output_gains.push_back(best_gain - parent_gain);
        output_thresholds.push_back(best_bucket);
        output_left_node_contribs.push_back(best_contrib_for_left);
        output_right_node_contribs.push_back(best_contrib_for_right);
      }  // for node_id
      const int num_nodes = output_node_ids.size();
      // output_node_ids
      Tensor* output_node_ids_t;
      OP_REQUIRES_OK(context,
                     output_node_ids_list.allocate(feature_idx, {num_nodes},
                                                   &output_node_ids_t));
      auto output_node_ids_vec = output_node_ids_t->vec<int32>();
      // output_gains
      Tensor* output_gains_t;
      OP_REQUIRES_OK(context, output_gains_list.allocate(
                                  feature_idx, {num_nodes}, &output_gains_t));
      auto output_gains_vec = output_gains_t->vec<float>();
      // output_thresholds
      Tensor* output_thresholds_t;
      OP_REQUIRES_OK(context,
                     output_thresholds_list.allocate(feature_idx, {num_nodes},
                                                     &output_thresholds_t));
      auto output_thresholds_vec = output_thresholds_t->vec<int32>();
      // output_left_node_contribs
      Tensor* output_left_node_contribs_t;
      OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
                                  feature_idx, {num_nodes, 1},
                                  &output_left_node_contribs_t));
      auto output_left_node_contribs_matrix =
          output_left_node_contribs_t->matrix<float>();
      // output_right_node_contribs
      Tensor* output_right_node_contribs_t;
      OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
                                  feature_idx, {num_nodes, 1},
                                  &output_right_node_contribs_t));
      auto output_right_node_contribs_matrix =
          output_right_node_contribs_t->matrix<float>();
      // Sets output tensors from vectors.
      for (int i = 0; i < num_nodes; ++i) {
        output_node_ids_vec(i) = output_node_ids[i];
        // Adjust the gains to penalize by tree complexity.
        output_gains_vec(i) = output_gains[i] - tree_complexity;
        output_thresholds_vec(i) = output_thresholds[i];
        output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
        // This op only supports 1-dimensional logits.
        output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
      }
    }  // for f
  }

 private:
  int max_splits_;
  int num_features_;
};

// V1 op that only supports single-dimensional logits.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
    BoostedTreesCalculateBestGainsPerFeatureOp);

// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitV2 instead.
class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
 public:
  explicit BoostedTreesCalculateBestFeatureSplitOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32 node_id_first = node_id_range(0);  // inclusive
    const int32 node_id_last = node_id_range(1);   // exclusive

    const Tensor* stats_summary_t;
    OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
    TTypes<float, 4>::ConstTensor stats_summary =
        stats_summary_t->tensor<float, 4>();
    const int32 feature_dims = stats_summary_t->dim_size(1);
    // The last bucket is for default/missing value.
    const int32 num_buckets = stats_summary_t->dim_size(2) - 1;
    const int32 logits_dim = logits_dim_;
    const int32 hessian_dim = stats_summary_t->dim_size(3) - logits_dim;
    DCHECK_GT(hessian_dim, 0);
    DCHECK_LE(hessian_dim, logits_dim * logits_dim);
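    // hessian_dim == logits_dim when only the diagonal of the hessian is
    // tracked; it grows up to logits_dim^2 when the full hessian is used.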

    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    DCHECK_GE(l1, 0);
    if (logits_dim_ > 1) {
      // Multi-class L1 regularization not supported yet.
      DCHECK_EQ(l1, 0);
    }

    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();
    DCHECK_GE(l2, 0);

    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();

    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    std::vector<int32> output_node_ids;
    std::vector<float> output_gains;
    std::vector<int32> output_feature_dimensions;
    std::vector<int32> output_thresholds;
    std::vector<Eigen::VectorXf> output_left_node_contribs;
    std::vector<Eigen::VectorXf> output_right_node_contribs;
    std::vector<string> output_split_types;

    // TODO(tanzheny) parallelize the computation.
    // Iterate each node and find the best gain per node.
    for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) {
      float best_gain = std::numeric_limits<float>::lowest();
      int32 best_bucket = 0;
      int32 best_f_dim = 0;
      string best_split_type;
      Eigen::VectorXf best_contrib_for_left(logits_dim);
      Eigen::VectorXf best_contrib_for_right(logits_dim);
      float parent_gain;

      // Including default bucket.
      ConstMatrixMap stats_mat(&stats_summary(node_id, 0, 0, 0),
                               num_buckets + 1, logits_dim + hessian_dim);
      const Eigen::VectorXf total_grad =
          stats_mat.leftCols(logits_dim).colwise().sum();
      const Eigen::VectorXf total_hess =
          stats_mat.rightCols(hessian_dim).colwise().sum();
      if (total_hess.norm() < min_node_weight) {
        continue;
      }
      Eigen::VectorXf parent_weight(logits_dim);
      CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &parent_weight,
                               &parent_gain);

      if (split_type_ == "inequality") {
        CalculateBestInequalitySplit(
            stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
            num_buckets, min_node_weight, l1, l2, &best_gain, &best_bucket,
            &best_f_dim, &best_split_type, &best_contrib_for_left,
            &best_contrib_for_right);
      } else {
        CalculateBestEqualitySplit(
            stats_summary, total_grad, total_hess, node_id, feature_dims,
            logits_dim, hessian_dim, num_buckets, l1, l2, &best_gain,
            &best_bucket, &best_f_dim, &best_split_type, &best_contrib_for_left,
            &best_contrib_for_right);
      }

      if (best_gain == std::numeric_limits<float>::lowest()) {
        // Do not add the node if no split was found.
        continue;
      }
      output_node_ids.push_back(node_id);
      // Subtract the parent gain to report the improvement from splitting.
      output_gains.push_back(best_gain - parent_gain);
      output_feature_dimensions.push_back(best_f_dim);
      // Default direction is fixed for dense splits.
      // TODO(tanzheny) account for default values.
      output_split_types.push_back(best_split_type);
      output_thresholds.push_back(best_bucket);
      output_left_node_contribs.push_back(best_contrib_for_left);
      output_right_node_contribs.push_back(best_contrib_for_right);
    }  // for node id
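    // Nodes skipped above (insufficient hessian or no valid split) are absent
    // from the outputs, so num_nodes may be smaller than the node id range.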
    const int num_nodes = output_node_ids.size();
    // output_node_ids
    Tensor* output_node_ids_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
                                                     &output_node_ids_t));
    auto output_node_ids_vec = output_node_ids_t->vec<int32>();

    // output_gains
    Tensor* output_gains_t;
    OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
                                                     &output_gains_t));
    auto output_gains_vec = output_gains_t->vec<float>();

    // output_feature_dimensions
    Tensor* output_feature_dimension_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("feature_dimensions", {num_nodes},
                                            &output_feature_dimension_t));
    auto output_feature_dimensions_vec =
        output_feature_dimension_t->vec<int32>();

    // output_thresholds
    Tensor* output_thresholds_t;
    OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
                                                     &output_thresholds_t));
    auto output_thresholds_vec = output_thresholds_t->vec<int32>();

    // output_left_node_contribs
    Tensor* output_left_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "left_node_contribs", {num_nodes, logits_dim},
                                &output_left_node_contribs_t));
    auto output_left_node_contribs_matrix =
        output_left_node_contribs_t->matrix<float>();

    // output_right_node_contribs
    Tensor* output_right_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "right_node_contribs", {num_nodes, logits_dim},
                                &output_right_node_contribs_t));
    auto output_right_node_contribs_matrix =
        output_right_node_contribs_t->matrix<float>();

    // split type
    Tensor* output_split_types_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("split_with_default_directions",
                                          {num_nodes}, &output_split_types_t));
    auto output_split_types_vec = output_split_types_t->vec<tstring>();

    // Sets output tensors from vectors.
    for (int i = 0; i < num_nodes; ++i) {
      output_node_ids_vec(i) = output_node_ids[i];
      // Adjust the gains to penalize by tree complexity.
      output_gains_vec(i) = output_gains[i] - tree_complexity;
      output_feature_dimensions_vec(i) = output_feature_dimensions[i];
      output_thresholds_vec(i) = output_thresholds[i];
      for (int j = 0; j < logits_dim; ++j) {
        output_left_node_contribs_matrix(i, j) =
            output_left_node_contribs[i][j];
        output_right_node_contribs_matrix(i, j) =
            output_right_node_contribs[i][j];
      }
      output_split_types_vec(i) = output_split_types[i];
    }
  }

 private:
  // TODO(crawles): Simplify the inequality path just like equality b/138329196.
  // Currently it cannot be simplified due to numerical instability in the math,
  // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g),
  // which caused the gain to become Inf when g approaches 0 but is not exactly
  // 0 and there is no regularization.
  // Calculate the best inequality split per node.
  void CalculateBestInequalitySplit(
      TTypes<float, 4>::ConstTensor stats_summary, const int32 node_id,
      const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim,
      const int32 num_buckets, const float min_node_weight, const float l1,
      const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    std::vector<Eigen::VectorXf> cum_grad;
    std::vector<Eigen::VectorXf> cum_hess;
    // Cumulative gradients for buckets [0, num_buckets); the totals below
    // also include the default bucket.
    cum_grad.reserve(num_buckets);
    cum_hess.reserve(num_buckets);

    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      ConstVectorMap default_stats_vec(
          &stats_summary(node_id, f_dim, num_buckets, 0),
          logits_dim + hessian_dim);
      Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
      Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
      cum_grad.clear();
      cum_hess.clear();
      Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
      Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
      // Sum all the gradients, including the default bucket.
      for (int bucket = 0; bucket <= num_buckets; ++bucket) {
        for (int i = 0; i < logits_dim; ++i) {
          total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
        }
        for (int i = 0; i < hessian_dim; ++i) {
          // Full hessian.
          total_hess[i] +=
              stats_summary(node_id, f_dim, bucket, logits_dim + i);
        }
        if (bucket < num_buckets) {
          cum_grad.push_back(total_grad);
          cum_hess.push_back(total_hess);
        }
      }
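      // cum_grad[b] / cum_hess[b] now hold the stats of buckets [0, b]. Each
      // threshold below is evaluated twice: once with the missing-value
      // (default) bucket routed to the left child, and once to the right.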
      const string kInequalityDefaultLeft =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_LEFT);
      const string kInequalityDefaultRight =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_RIGHT);

      // Iterate from left to right, excluding the default bucket.
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        // Default value goes to the left node.
        const Eigen::VectorXf total_left_grad =
            cum_grad[bucket] + missing_bucket_grad;
        const Eigen::VectorXf total_left_hess =
            cum_hess[bucket] + missing_bucket_hess;
        MaybeUpdateBestSplit(
            total_left_grad, total_grad - total_left_grad, total_left_hess,
            total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
        // Default value goes to the right node.
        MaybeUpdateBestSplit(
            cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
            total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
      }  // for bucket
    }
  }

  // Calculate the best equality split per node. An equality split sends the
  // examples in the candidate bucket to the left child and all remaining
  // examples (including missing values) to the right child.
  void CalculateBestEqualitySplit(
      TTypes<float, 4>::ConstTensor stats_summary,
      const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
      const int32 node_id, const int32 feature_dims, const int32 logits_dim,
      const int32 hessian_dim, const int32 num_buckets, const float l1,
      const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    const string kEqualityDefaultRight =
        boosted_trees::SplitTypeWithDefault_Name(
            boosted_trees::EQUALITY_DEFAULT_RIGHT);
    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
                                 logits_dim + hessian_dim);
        Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
        Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
        MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
                             total_hess - curr_hess, logits_dim, bucket, f_dim,
                             l1, l2, kEqualityDefaultRight, best_gain,
                             best_bucket, best_f_dim, best_split_type,
                             best_contrib_for_left, best_contrib_for_right);
      }
    }
  }

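  // Computes child weights and gains for the proposed (bucket, f_dim) split
  // and updates the best-split outputs if the combined left + right gain
  // improves on *best_gain.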
  void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
                            const Eigen::VectorXf& grad_for_right,
                            const Eigen::VectorXf& hess_for_left,
                            const Eigen::VectorXf& hess_for_right,
                            const int32 logits_dim, const int32 bucket,
                            const int32 f_dim, const float l1, const float l2,
                            const string split_type, float* best_gain,
                            int32* best_bucket, int32* best_f_dim,
                            string* best_split_type,
                            Eigen::VectorXf* best_contrib_for_left,
                            Eigen::VectorXf* best_contrib_for_right) {
    // Left child.
    Eigen::VectorXf contrib_for_left(logits_dim);
    float gain_for_left;
    CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
                             &contrib_for_left, &gain_for_left);
    Eigen::VectorXf contrib_for_right(logits_dim);
    float gain_for_right;
    CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
                             &contrib_for_right, &gain_for_right);
    if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
      *best_gain = gain_for_left + gain_for_right;
      *best_bucket = bucket;
      *best_f_dim = f_dim;
      *best_contrib_for_left = contrib_for_left;
      *best_contrib_for_right = contrib_for_right;
      *best_split_type = split_type;
    }
  }

  int logits_dim_;
  string split_type_;
};

// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitV2 instead.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
    BoostedTreesCalculateBestFeatureSplitOp);

// V2 op.
class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
 public:
  explicit BoostedTreesCalculateBestFeatureSplitV2(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32 node_id_first = node_id_range(0);  // Inclusive.
    const int32 node_id_last = node_id_range(1);   // Exclusive.

    // Get stats_summaries_list.
    OpInputList stats_summaries_list;
    OP_REQUIRES_OK(context, context->input_list("stats_summaries_list",
                                                &stats_summaries_list));

    // Infer dimensions of a stats_summary.
    DCHECK_GT(stats_summaries_list.size(), 0);
    const int32 feature_dims = stats_summaries_list[0].dim_size(1);
    // The last bucket is for default/missing value.
    const int32 num_buckets = stats_summaries_list[0].dim_size(2) - 1;
    const int32 logits_dim = logits_dim_;
    const int32 hessian_dim = stats_summaries_list[0].dim_size(3) - logits_dim;
    DCHECK_GT(hessian_dim, 0);
    DCHECK_LE(hessian_dim, logits_dim * logits_dim);

    // Vector of stats_summaries; each element is stats for feature of shape
    // [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim].
    std::vector<TTypes<float, 4>::ConstTensor> stats_summaries;
    DCHECK_EQ(stats_summaries_list.size(), num_features_);
    stats_summaries.reserve(num_features_);
    for (const auto& tensor : stats_summaries_list) {
      stats_summaries.emplace_back(tensor.tensor<float, 4>());
    }

    // Split types.
    const Tensor* split_types_t;
    OP_REQUIRES_OK(context, context->input("split_types", &split_types_t));
    const auto split_types = split_types_t->vec<tstring>();
    DCHECK_EQ(split_types.size(), num_features_);
    // Validate.
    for (int i = 0; i < num_features_; ++i) {
      if (!(split_types(i) == kInequalitySplit ||
            split_types(i) == kEqualitySplit)) {
        OP_REQUIRES_OK(
            context,
            errors::Aborted(
                "Operation received an exception: Incorrect split type"));
      }
    }
    // Feature ids.
    const Tensor* candidate_feature_ids_t;
    OP_REQUIRES_OK(context, context->input("candidate_feature_ids",
                                           &candidate_feature_ids_t));
    const auto candidate_feature_ids = candidate_feature_ids_t->vec<int32>();
    DCHECK_EQ(candidate_feature_ids.size(), num_features_);

    // L1, L2, tree_complexity, min_node_weight.
    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();
    DCHECK_GE(l1, 0);
    if (logits_dim_ > 1) {
      // Multi-class L1 regularization not supported yet.
      DCHECK_EQ(l1, 0);
    }
    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();
    DCHECK_GE(l2, 0);
    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();
    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    std::vector<int32> output_node_ids;
    std::vector<float> output_gains;
    std::vector<int32> output_feature_ids;
    std::vector<int32> output_feature_dimensions;
    std::vector<int32> output_thresholds;
    std::vector<Eigen::VectorXf> output_left_node_contribs;
    std::vector<Eigen::VectorXf> output_right_node_contribs;
    std::vector<string> output_split_types;

    // TODO(tanzheny) parallelize the computation.
    // Iterate each node and find the best gain per node.
    float parent_gain;
    for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) {
      float best_gain = std::numeric_limits<float>::lowest();
      int32 best_bucket;
      int32 best_f_id;
      int32 best_f_dim;
      string best_split_type;
      Eigen::VectorXf best_contrib_for_left(logits_dim);
      Eigen::VectorXf best_contrib_for_right(logits_dim);

      // Sum of gradient and hessian. Compute parent gain using the first
      // feature.
      ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0),
                               num_buckets + 1,  // Including default bucket.
                               logits_dim + hessian_dim);
      const Eigen::VectorXf total_grad =
          stats_mat.leftCols(logits_dim).colwise().sum();
      const Eigen::VectorXf total_hess =
          stats_mat.rightCols(hessian_dim).colwise().sum();
      if (total_hess.norm() < min_node_weight) {
        continue;
      }
      Eigen::VectorXf unused(logits_dim);
      CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
                               &parent_gain);
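      // Scan every candidate feature and keep the best split across features.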
      for (int f_idx = 0; f_idx < num_features_; ++f_idx) {
        const string split_type = split_types(f_idx);
        TTypes<float, 4>::ConstTensor stats_summary = stats_summaries[f_idx];
        float f_best_gain = std::numeric_limits<float>::lowest();
        int32 f_best_bucket;
        int32 f_best_f_dim;
        string f_best_split_type;
        Eigen::VectorXf f_best_contrib_for_left(logits_dim);
        Eigen::VectorXf f_best_contrib_for_right(logits_dim);

        if (split_type == kInequalitySplit) {
          CalculateBestInequalitySplit(
              stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
              num_buckets, min_node_weight, l1, l2, &f_best_gain,
              &f_best_bucket, &f_best_f_dim, &f_best_split_type,
              &f_best_contrib_for_left, &f_best_contrib_for_right);
        } else {
          CalculateBestEqualitySplit(
              stats_summary, total_grad, total_hess, node_id, feature_dims,
              logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain,
              &f_best_bucket, &f_best_f_dim, &f_best_split_type,
              &f_best_contrib_for_left, &f_best_contrib_for_right);
        }
        if (f_best_gain > best_gain) {
          best_gain = f_best_gain;
          best_f_id = candidate_feature_ids(f_idx);
          best_f_dim = f_best_f_dim;
          best_split_type = f_best_split_type;
          best_bucket = f_best_bucket;
          best_contrib_for_left = f_best_contrib_for_left;
          best_contrib_for_right = f_best_contrib_for_right;
        }
      }  // for f_idx
      if (best_gain == std::numeric_limits<float>::lowest()) {
        // Do not add the node if no split was found.
        continue;
      }
      output_node_ids.push_back(node_id);
      // Subtract the parent gain to report the improvement from splitting.
      output_gains.push_back(best_gain - parent_gain);
      output_feature_ids.push_back(best_f_id);
      output_feature_dimensions.push_back(best_f_dim);
      // Default direction is fixed for dense splits.
      // TODO(tanzheny) account for default values.
      output_split_types.push_back(best_split_type);
      output_thresholds.push_back(best_bucket);
      output_left_node_contribs.push_back(best_contrib_for_left);
      output_right_node_contribs.push_back(best_contrib_for_right);
    }  // for node id
    const int num_nodes = output_node_ids.size();
    // output_node_ids
    Tensor* output_node_ids_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
                                                     &output_node_ids_t));
    auto output_node_ids_vec = output_node_ids_t->vec<int32>();

    // output_gains
    Tensor* output_gains_t;
    OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
                                                     &output_gains_t));
    auto output_gains_vec = output_gains_t->vec<float>();

    // output_feature_ids
    Tensor* output_features_ids_t;
    OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes},
                                                     &output_features_ids_t));
    auto output_features_vec = output_features_ids_t->vec<int32>();

    // output_feature_dimensions
    Tensor* output_feature_dimension_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("feature_dimensions", {num_nodes},
                                            &output_feature_dimension_t));
    auto output_feature_dimensions_vec =
        output_feature_dimension_t->vec<int32>();

    // output_thresholds
    Tensor* output_thresholds_t;
    OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
                                                     &output_thresholds_t));
    auto output_thresholds_vec = output_thresholds_t->vec<int32>();

    // output_left_node_contribs
    Tensor* output_left_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "left_node_contribs", {num_nodes, logits_dim},
                                &output_left_node_contribs_t));
    auto output_left_node_contribs_matrix =
        output_left_node_contribs_t->matrix<float>();

    // output_right_node_contribs
    Tensor* output_right_node_contribs_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "right_node_contribs", {num_nodes, logits_dim},
                                &output_right_node_contribs_t));
    auto output_right_node_contribs_matrix =
        output_right_node_contribs_t->matrix<float>();

    // split type
    Tensor* output_split_types_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("split_with_default_directions",
                                          {num_nodes}, &output_split_types_t));
    auto output_split_types_vec = output_split_types_t->vec<tstring>();

    // Sets output tensors from vectors.
    for (int i = 0; i < num_nodes; ++i) {
      output_node_ids_vec(i) = output_node_ids[i];
      output_features_vec(i) = output_feature_ids[i];
      // Adjust the gains to penalize by tree complexity.
      output_gains_vec(i) = output_gains[i] - tree_complexity;
      output_feature_dimensions_vec(i) = output_feature_dimensions[i];
      output_thresholds_vec(i) = output_thresholds[i];
      for (int j = 0; j < logits_dim; ++j) {
        output_left_node_contribs_matrix(i, j) =
            output_left_node_contribs[i][j];
        output_right_node_contribs_matrix(i, j) =
            output_right_node_contribs[i][j];
      }
      output_split_types_vec(i) = output_split_types[i];
    }
  }

 private:
  // TODO(crawles): Simplify the inequality path just like equality b/138329196.
  // Currently it cannot be simplified due to numerical instability in the math,
  // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g),
  // which caused the gain to become Inf when g approaches 0 but is not exactly
  // 0 and there is no regularization.
  // Calculate the best inequality split per node.
  void CalculateBestInequalitySplit(
      TTypes<float, 4>::ConstTensor stats_summary, const int32 node_id,
      const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim,
      const int32 num_buckets, const float min_node_weight, const float l1,
      const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    std::vector<Eigen::VectorXf> cum_grad;
    std::vector<Eigen::VectorXf> cum_hess;
    // Cumulative gradients for buckets [0, num_buckets); the totals below
    // also include the default bucket.
    cum_grad.reserve(num_buckets);
    cum_hess.reserve(num_buckets);

    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      ConstVectorMap default_stats_vec(
          &stats_summary(node_id, f_dim, num_buckets, 0),
          logits_dim + hessian_dim);
      Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
      Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
      cum_grad.clear();
      cum_hess.clear();
      Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
      Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
      // Sum all the gradients, including the default bucket.
      for (int bucket = 0; bucket <= num_buckets; ++bucket) {
        for (int i = 0; i < logits_dim; ++i) {
          total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
        }
        for (int i = 0; i < hessian_dim; ++i) {
          // Full hessian.
          total_hess[i] +=
              stats_summary(node_id, f_dim, bucket, logits_dim + i);
        }
        if (bucket < num_buckets) {
          cum_grad.push_back(total_grad);
          cum_hess.push_back(total_hess);
        }
      }
      const string kInequalityDefaultLeft =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_LEFT);
      const string kInequalityDefaultRight =
          boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_RIGHT);

      // Iterate from left to right, excluding the default bucket.
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        // Default value goes to the left node.
        const Eigen::VectorXf total_left_grad =
            cum_grad[bucket] + missing_bucket_grad;
        const Eigen::VectorXf total_left_hess =
            cum_hess[bucket] + missing_bucket_hess;
        MaybeUpdateBestSplit(
            total_left_grad, total_grad - total_left_grad, total_left_hess,
            total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
        // Default value goes to the right node.
        MaybeUpdateBestSplit(
            cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
            total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
            kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
            best_split_type, best_contrib_for_left, best_contrib_for_right);
      }  // for bucket
    }
  }

  // Calculate the best equality split per node. An equality split sends the
  // examples in the candidate bucket to the left child and all remaining
  // examples (including missing values) to the right child.
  void CalculateBestEqualitySplit(
      TTypes<float, 4>::ConstTensor stats_summary,
      const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
      const int32 node_id, const int32 feature_dims, const int32 logits_dim,
      const int32 hessian_dim, const int32 num_buckets, const float l1,
      const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
      string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
      Eigen::VectorXf* best_contrib_for_right) {
    const string kEqualityDefaultRight =
        boosted_trees::SplitTypeWithDefault_Name(
            boosted_trees::EQUALITY_DEFAULT_RIGHT);
    for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
      for (int bucket = 0; bucket < num_buckets; ++bucket) {
        ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
                                 logits_dim + hessian_dim);
        Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
        Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
        MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
                             total_hess - curr_hess, logits_dim, bucket, f_dim,
                             l1, l2, kEqualityDefaultRight, best_gain,
                             best_bucket, best_f_dim, best_split_type,
                             best_contrib_for_left, best_contrib_for_right);
      }
    }
  }

  void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
                            const Eigen::VectorXf& grad_for_right,
                            const Eigen::VectorXf& hess_for_left,
                            const Eigen::VectorXf& hess_for_right,
                            const int32 logits_dim, const int32 bucket,
                            const int32 f_dim, const float l1, const float l2,
                            const string split_type, float* best_gain,
                            int32* best_bucket, int32* best_f_dim,
                            string* best_split_type,
                            Eigen::VectorXf* best_contrib_for_left,
                            Eigen::VectorXf* best_contrib_for_right) {
    // Left child.
    Eigen::VectorXf contrib_for_left(logits_dim);
    float gain_for_left;
    CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
                             &contrib_for_left, &gain_for_left);
    Eigen::VectorXf contrib_for_right(logits_dim);
    float gain_for_right;
    CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
                             &contrib_for_right, &gain_for_right);
    if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
      *best_gain = gain_for_left + gain_for_right;
      *best_bucket = bucket;
      *best_f_dim = f_dim;
      *best_contrib_for_left = contrib_for_left;
      *best_contrib_for_right = contrib_for_right;
      *best_split_type = split_type;
    }
  }
  int num_features_;
  int logits_dim_;
};

// V2 op that supports multi-class.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU),
    BoostedTreesCalculateBestFeatureSplitV2);

// Map from bucket id to vector of statistics.
typedef std::map<int32, std::vector<float>> BucketMap;
typedef BucketMap::iterator BucketMapIterator;
// Map from feature dimension to BucketMap.
typedef std::map<int32, BucketMap> FeatureMap;
typedef FeatureMap::iterator FeatureMapIterator;

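// Sparse variant: statistics arrive in COO form. Each row of
// stats_summary_indices is (node_id, feature_dim, bucket_id, stat_dim) and
// indexes the matching entry of stats_summary_values; entries are regrouped
// per node into a FeatureMap before split evaluation.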
class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
 public:
  explicit BoostedTreesSparseCalculateBestFeatureSplitOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    // TODO(crawles): Use logits_dim_ for multi-class split.
    OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
    // TODO(tanzheny): Use this for equality split.
    OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_id_range
    const Tensor* node_id_range_t;
    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
    const auto node_id_range = node_id_range_t->vec<int32>();
    const int32 node_id_first = node_id_range(0);  // inclusive
    const int32 node_id_last = node_id_range(1);   // exclusive

    const Tensor* stats_summary_indices_t;
    OP_REQUIRES_OK(context, context->input("stats_summary_indices",
                                           &stats_summary_indices_t));
    const auto stats_summary_indices = stats_summary_indices_t->matrix<int32>();
    const int32 num_sparse_entries = stats_summary_indices_t->dim_size(0);

    const Tensor* stats_summary_values_t;
    OP_REQUIRES_OK(context, context->input("stats_summary_values",
                                           &stats_summary_values_t));
    const auto stats_summary_values = stats_summary_values_t->vec<float>();

    const Tensor* stats_summary_shape_t;
    OP_REQUIRES_OK(
        context, context->input("stats_summary_shape", &stats_summary_shape_t));
    const auto stats_summary_shape = stats_summary_shape_t->vec<int32>();
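    // The dense shape is [max_splits, feature_dims, num_buckets + 1,
    // stats_dims]; as in the dense ops, the last bucket holds stats for
    // default/missing values.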
    const int32 num_buckets = stats_summary_shape(2) - 1;
    const int32 stats_dims = stats_summary_shape(3);

    const Tensor* l1_t;
    OP_REQUIRES_OK(context, context->input("l1", &l1_t));
    const auto l1 = l1_t->scalar<float>()();

    const Tensor* l2_t;
    OP_REQUIRES_OK(context, context->input("l2", &l2_t));
    const auto l2 = l2_t->scalar<float>()();

    const Tensor* tree_complexity_t;
    OP_REQUIRES_OK(context,
                   context->input("tree_complexity", &tree_complexity_t));
    const auto tree_complexity = tree_complexity_t->scalar<float>()();

    const Tensor* min_node_weight_t;
    OP_REQUIRES_OK(context,
                   context->input("min_node_weight", &min_node_weight_t));
    const auto min_node_weight = min_node_weight_t->scalar<float>()();

    std::vector<int32> output_node_ids;
    std::vector<float> output_gains;
    std::vector<int32> output_feature_dimensions;
    std::vector<int32> output_thresholds;
    std::vector<float> output_left_node_contribs;
    std::vector<float> output_right_node_contribs;
    std::vector<string> output_split_types;

    FeatureMap f_map;

    int32 previous_node_id = -1;
    for (int idx = 0; idx < num_sparse_entries; ++idx) {
      int32 node_id = stats_summary_indices(idx, 0);
      if (node_id != previous_node_id) {
        process_node(f_map, &output_node_ids, &output_gains,
                     &output_feature_dimensions, &output_thresholds,
                     &output_left_node_contribs, &output_right_node_contribs,
                     &output_split_types, previous_node_id, min_node_weight, l1,
                     l2, num_buckets);
        f_map.clear();
      }
      previous_node_id = node_id;
      DCHECK_LE(node_id_first, node_id);
      DCHECK_LT(node_id, node_id_last);
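      // Record this entry's statistic in the per-feature, per-bucket map.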
      const int32 feature_dim = stats_summary_indices(idx, 1);
      const int32 bucket_id = stats_summary_indices(idx, 2);
      const int32 stat_dim = stats_summary_indices(idx, 3);
      std::pair<FeatureMapIterator, bool> const& f_insert_result = f_map.insert(
          FeatureMapIterator::value_type(feature_dim, BucketMap()));
      auto& b_map = f_insert_result.first->second;
      std::pair<BucketMapIterator, bool> const& b_insert_result =
          b_map.insert(BucketMapIterator::value_type(
              bucket_id, std::vector<float>(stats_dims)));
      auto& stats = b_insert_result.first->second;
      stats[stat_dim] = stats_summary_values(idx);
    }  // for idx over sparse entries
    // Process the last node id.
    process_node(f_map, &output_node_ids, &output_gains,
                 &output_feature_dimensions, &output_thresholds,
                 &output_left_node_contribs, &output_right_node_contribs,
                 &output_split_types, previous_node_id, min_node_weight, l1, l2,
                 num_buckets);

    const int num_nodes = output_node_ids.size();
    // output_node_ids
    Tensor* output_node_ids_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
                                                     &output_node_ids_t));
    auto output_node_ids_vec = output_node_ids_t->vec<int32>();

    // output_gains
    Tensor* output_gains_t;
    OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
                                                     &output_gains_t));
    auto output_gains_vec = output_gains_t->vec<float>();

    // output_feature_dimensions
    Tensor* output_feature_dimension_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("feature_dimensions", {num_nodes},
                                            &output_feature_dimension_t));
    auto output_feature_dimensions_vec =
        output_feature_dimension_t->vec<int32>();

    // output_thresholds
    Tensor* output_thresholds_t;
    OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
                                                     &output_thresholds_t));
    auto output_thresholds_vec = output_thresholds_t->vec<int32>();

    // output_left_node_contribs
    Tensor* output_left_node_contribs_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("left_node_contribs", {num_nodes, 1},
                                          &output_left_node_contribs_t));
    auto output_left_node_contribs_matrix =
        output_left_node_contribs_t->matrix<float>();

    // output_right_node_contribs
    Tensor* output_right_node_contribs_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("right_node_contribs", {num_nodes, 1},
                                          &output_right_node_contribs_t));
    auto output_right_node_contribs_matrix =
        output_right_node_contribs_t->matrix<float>();

    // split type
    Tensor* output_split_types_t;
    OP_REQUIRES_OK(
        context, context->allocate_output("split_with_default_directions",
                                          {num_nodes}, &output_split_types_t));
    auto output_split_types_vec = output_split_types_t->vec<tstring>();

    // Sets output tensors from vectors.
    for (int i = 0; i < num_nodes; ++i) {
      output_node_ids_vec(i) = output_node_ids[i];
      // Adjust the gains to penalize by tree complexity.
      output_gains_vec(i) = output_gains[i] - tree_complexity;
      output_feature_dimensions_vec(i) = output_feature_dimensions[i];
      output_thresholds_vec(i) = output_thresholds[i];
      // TODO(crawles): change this for multi-class.
      output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
      output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
      output_split_types_vec(i) = output_split_types[i];
    }
  }

 protected:
  void process_node(const FeatureMap& f_map,
                    std::vector<int32>* output_node_ids,
                    std::vector<float>* output_gains,
                    std::vector<int32>* output_feature_dimensions,
                    std::vector<int32>* output_thresholds,
                    std::vector<float>* output_left_node_contribs,
                    std::vector<float>* output_right_node_contribs,
                    std::vector<string>* output_split_types,
                    const int32 node_id, const float min_node_weight,
                    const float l1, const float l2, const int32 num_buckets) {
    float parent_gain;
    Eigen::VectorXf unused(logits_dim_);
    Eigen::MatrixXf identity;
    identity.setIdentity(1, 1);

    // Start processing for the previous node id.
    float best_gain = std::numeric_limits<float>::lowest();
    int32 best_bucket = 0;
    int32 best_f_dim = 0;
    string best_split_type = boosted_trees::SplitTypeWithDefault_Name(
        boosted_trees::INEQUALITY_DEFAULT_LEFT);
    float best_contrib_for_left = 0.0;
    float best_contrib_for_right = 0.0;
    // The sum of gradients, including the default bucket.
    float total_grad = 0;
    // The sum of hessians, including the default bucket.
    float total_hess = 0;

    for (auto f_iter = f_map.begin(); f_iter != f_map.end(); ++f_iter) {
      const int32 feature_dim = f_iter->first;
      // Take a reference to avoid copying the per-bucket stats map.
      const auto& buckets_to_stats_map = f_iter->second;

      // The very last bucket contains stats for missing values.
      // TODO(crawles): use vector for multi-class.
      const float default_grad =
          (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
               ? 0
               : buckets_to_stats_map.at(num_buckets)[0]);
      const float default_hess =
          (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
               ? 0
               : buckets_to_stats_map.at(num_buckets)[1]);

      if (f_iter == f_map.begin()) {
        // First, get the sum of grads and hessians, including the default
        // bucket.
        for (auto b_iter = buckets_to_stats_map.begin();
             b_iter != buckets_to_stats_map.end(); ++b_iter) {
          total_grad += b_iter->second[0];
          total_hess += b_iter->second[1];
        }
        if (total_hess < min_node_weight) {
          // Do not split the node: the hessian sum is below min_node_weight.
          break;
        }
        CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
                                 l1, l2, &unused, &parent_gain);
      }

      float total_left_grad = 0;
      float total_left_hess = 0;
      for (auto b_iter = buckets_to_stats_map.begin();
           b_iter != buckets_to_stats_map.end(); ++b_iter) {
        const int32 bucket_id = b_iter->first;
        // total_left_stats should exclude stats from the default bucket.
        if (bucket_id == num_buckets) {
          break;
        }
        // TODO(crawles): vector for multi-class.
        total_left_grad += b_iter->second[0];
        total_left_hess += b_iter->second[1];
        // From left to right, default right.
        // Left child.
        Eigen::VectorXf contrib_for_left(1);
        float gain_for_left;
        CalculateWeightsAndGains(total_left_grad * identity,
                                 total_left_hess * identity, l1, l2,
                                 &contrib_for_left, &gain_for_left);
        // Right child.
        Eigen::VectorXf contrib_for_right(1);
        float gain_for_right;
        CalculateWeightsAndGains((total_grad - total_left_grad) * identity,
                                 (total_hess - total_left_hess) * identity, l1,
                                 l2, &contrib_for_right, &gain_for_right);
        if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
          best_gain = gain_for_left + gain_for_right;
          best_bucket = bucket_id;
          best_f_dim = feature_dim;
          best_split_type = boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_RIGHT);
          best_contrib_for_left = contrib_for_left[0];
          best_contrib_for_right = contrib_for_right[0];
        }

        // From right to left, default left.
        CalculateWeightsAndGains((total_left_grad + default_grad) * identity,
                                 (total_left_hess + default_hess) * identity,
                                 l1, l2, &contrib_for_left, &gain_for_left);
        CalculateWeightsAndGains(
            (total_grad - default_grad - total_left_grad) * identity,
            (total_hess - default_hess - total_left_hess) * identity, l1, l2,
            &contrib_for_right, &gain_for_right);
        if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
          best_gain = gain_for_left + gain_for_right;
          best_bucket = bucket_id;
          best_f_dim = feature_dim;
          best_split_type = boosted_trees::SplitTypeWithDefault_Name(
              boosted_trees::INEQUALITY_DEFAULT_LEFT);
          best_contrib_for_left = contrib_for_left[0];
          best_contrib_for_right = contrib_for_right[0];
        }
      }  // for bucket_id
    }    // for feature_dim
    if (best_gain != std::numeric_limits<float>::lowest()) {
      output_node_ids->push_back(node_id);
      // Remove the parent gain.
      output_gains->push_back(best_gain - parent_gain);
      output_feature_dimensions->push_back(best_f_dim);
      output_split_types->push_back(best_split_type);
      output_thresholds->push_back(best_bucket);
      output_left_node_contribs->push_back(best_contrib_for_left);
      output_right_node_contribs->push_back(best_contrib_for_right);
    }
  }
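
  // Illustrative note (an assumption about tree_helper.h, not a quote of it):
  // for this single-logit path, CalculateWeightsAndGains is expected to follow
  // the usual regularized leaf weight and gain. Ignoring l1, for a gradient
  // sum G and hessian sum H:
  //
  //   w* = -G / (H + l2),   gain = G^2 / (H + l2)
  //
  // A candidate split is scored as gain_for_left + gain_for_right, and the
  // parent_gain computed above is subtracted when the best split is emitted,
  // so only the improvement over not splitting is reported.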

 private:
  int logits_dim_;
  string split_type_;
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesSparseCalculateBestFeatureSplit").Device(DEVICE_CPU),
    BoostedTreesSparseCalculateBestFeatureSplitOp);

class BoostedTreesMakeStatsSummaryOp : public OpKernel {
 public:
  explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_ids
    const Tensor* node_ids_t;
    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
    const auto node_ids = node_ids_t->vec<int32>();
    // gradients
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    const auto gradients = gradients_t->matrix<float>();
    // hessians
    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    const auto hessians = hessians_t->matrix<float>();
    // bucketized_features
    OpInputList bucketized_features_list;
    OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
                                                &bucketized_features_list));
    // Infer batch size.
    const int64 batch_size = node_ids_t->dim_size(0);

    // Allocate a temporary stats tensor (rank 4), upcasting to double.
    Tensor temp_stats_double_t;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DT_DOUBLE,
                                {num_features_, max_splits_, num_buckets_, 2},
                                &temp_stats_double_t));
    auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
    temp_stats_double.setZero();

    // Partition by node, and then bucketize.
    for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
      const auto& features = bucketized_features_list[feature_idx].vec<int32>();
      for (int i = 0; i < batch_size; ++i) {
        const int32 node = node_ids(i);
        const int32 bucket = features(i);
        temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
        temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
      }
    }

    // Copy the temp tensor over to the output tensor, downcasting to float.
    Tensor* output_stats_summary_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "stats_summary", temp_stats_double_t.shape(),
                                &output_stats_summary_t));
    output_stats_summary_t->tensor<float, 4>() =
        temp_stats_double.template cast<float>();
  }

 private:
  int max_splits_;
  int num_buckets_;
  int num_features_;
};

REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
                        BoostedTreesMakeStatsSummaryOp);
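
// Illustrative note (not part of the original code): the op above produces a
// dense [num_features, max_splits, num_buckets, 2] summary. For a batch
// instance i routed to node n whose bucketized value for feature f is b, the
// cell [f, n, b, 0] accumulates gradients(i, 0) and [f, n, b, 1] accumulates
// hessians(i, 0).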

// TODO(tanzheny): Add an option of default value into the API interface.
class BoostedTreesAggregateStatsOp : public OpKernel {
 public:
  explicit BoostedTreesAggregateStatsOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_ids.
    const Tensor* node_ids_t;
    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
    const auto node_ids = node_ids_t->vec<int32>();

    // gradients.
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    const auto gradients = gradients_t->matrix<float>();

    // hessians.
    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    const auto hessians = hessians_t->matrix<float>();

    // feature.
    const Tensor* feature_t;
    OP_REQUIRES_OK(context, context->input("feature", &feature_t));
    const auto feature = feature_t->matrix<int32>();

    // Infer batch size, feature dimension and stats dimension.
    const int64 batch_size = node_ids_t->dim_size(0);
    const int64 logits_dims = gradients_t->dim_size(1);
    const int64 hessians_dims = hessians_t->dim_size(1);
    const int64 stats_dims = logits_dims + hessians_dims;
    const int64 feature_dims = feature_t->dim_size(1);

    // Allocate a temporary stats tensor (rank 4), upcasting to double.
    // A default bucket is added at the end for missing/default values.
    Tensor temp_stats_double_t;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DT_DOUBLE,
                     {max_splits_, feature_dims, num_buckets_ + 1, stats_dims},
                     &temp_stats_double_t));
    auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
    temp_stats_double.setZero();

    for (int i = 0; i < batch_size; ++i) {
      const int32 node = node_ids(i);
      for (int feature_dim = 0; feature_dim < feature_dims; ++feature_dim) {
        const int32 feature_value = feature(i, feature_dim);
        const int32 bucket =
            (feature_value == -1) ? num_buckets_ : feature_value;
        for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
          temp_stats_double(node, feature_dim, bucket, stat_dim) +=
              gradients(i, stat_dim);
        }
        for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
          temp_stats_double(node, feature_dim, bucket, stat_dim) +=
              hessians(i, stat_dim - logits_dims);
        }
      }
    }
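
    // Illustrative example (not part of the original code): with
    // num_buckets_ = 10, a bucketized feature value of 7 is accumulated into
    // bucket 7, while the missing-value marker -1 is accumulated into the
    // trailing default bucket 10 (hence the num_buckets_ + 1 axis above).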

    // Copy temp tensor over to output tensor, downcasting to float.
    Tensor* output_stats_summary_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "stats_summary", temp_stats_double_t.shape(),
                                &output_stats_summary_t));
    output_stats_summary_t->tensor<float, 4>() =
        temp_stats_double.template cast<float>();
  }

 private:
  int max_splits_;
  int num_buckets_;
};

REGISTER_KERNEL_BUILDER(Name("BoostedTreesAggregateStats").Device(DEVICE_CPU),
                        BoostedTreesAggregateStatsOp);

// Key based on node id, feature dimension and bucket id.
struct StatsPartitionKey {
  StatsPartitionKey(const int32 node_id, const int32 feature_dim,
                    const int32 bucket_id)
      : node_id(node_id), feature_dim(feature_dim), bucket_id(bucket_id) {}

  bool operator==(const StatsPartitionKey& other) const {
    return (node_id == other.node_id) && (feature_dim == other.feature_dim) &&
           (bucket_id == other.bucket_id);
  }

  // Comparator for StatsPartitionKey.
  struct Less {
    bool operator()(const StatsPartitionKey& a,
                    const StatsPartitionKey& b) const {
      if (a.node_id < b.node_id) {
        return true;
      }
      if ((a.node_id == b.node_id) && (a.feature_dim < b.feature_dim)) {
        return true;
      }
      if ((a.node_id == b.node_id) && (a.feature_dim == b.feature_dim) &&
          (a.bucket_id < b.bucket_id)) {
        return true;
      }
      return false;
    }
  };

  // Tree node id.
  int32 node_id;
  // Dimension within the feature column.
  int32 feature_dim;
  // Bucketized feature value.
  int32 bucket_id;
};

typedef std::map<StatsPartitionKey, std::vector<float>, StatsPartitionKey::Less>
    StatsPartitionMap;
typedef StatsPartitionMap::iterator StatsPartitionIterator;
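
// Illustrative example (not part of the original code):
// StatsPartitionKey::Less orders keys lexicographically by
// (node_id, feature_dim, bucket_id), e.g. {0, 1, 2} < {0, 2, 0} < {1, 0, 0},
// so iterating a StatsPartitionMap visits each node's stats dimension by
// dimension and bucket by bucket, which is the order the sparse summary is
// serialized in below.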

// Key based on instance and feature dimension.
struct InstanceFeatureDimKey {
  InstanceFeatureDimKey() : instance(-1), feature_dim(-1) {}

  InstanceFeatureDimKey(const int32 instance, const int32 feature_dim)
      : instance(instance), feature_dim(feature_dim) {}

  bool operator==(const InstanceFeatureDimKey& other) const {
    return (instance == other.instance) && (feature_dim == other.feature_dim);
  }

  // Comparator for InstanceFeatureDimKey.
  struct Less {
    bool operator()(const InstanceFeatureDimKey& a,
                    const InstanceFeatureDimKey& b) const {
      if (a.instance < b.instance) {
        return true;
      }
      if ((a.instance == b.instance) && (a.feature_dim < b.feature_dim)) {
        return true;
      }
      return false;
    }
  };

  // Instance id within a batch.
  int32 instance;
  // Dimension within the feature column.
  int32 feature_dim;
};

// Add statistics to StatsPartitionMap for (instance, feature dim, bucket id).
static void AddInstanceStatsToMap(const int32 instance, const int32 feature_dim,
                                  const int32 bucket_id,
                                  const int32 logits_dims,
                                  const int32 stats_dims,
                                  StatsPartitionMap* stats_map,
                                  const TTypes<float>::ConstMatrix& gradients,
                                  const TTypes<float>::ConstMatrix& hessians,
                                  const TTypes<int32>::ConstVec& node_ids) {
  const int32 node_id = node_ids(instance);
  const auto key = StatsPartitionKey(node_id, feature_dim, bucket_id);
  std::pair<StatsPartitionIterator, bool> const& insert_result =
      stats_map->insert(StatsPartitionIterator::value_type(
          key, std::vector<float>(stats_dims, 0.0f)));
  auto& stats = insert_result.first->second;
  for (int stat_dim = 0; stat_dim < logits_dims; ++stat_dim) {
    stats[stat_dim] += gradients(instance, stat_dim);
  }
  for (int stat_dim = logits_dims; stat_dim < stats_dims; ++stat_dim) {
    stats[stat_dim] += hessians(instance, stat_dim - logits_dims);
  }
}
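
// Illustrative note (not part of the original code): the map insert above is
// a no-op when the key already exists, so all instances that land on the same
// (node_id, feature_dim, bucket_id) accumulate into a single stats vector of
// length stats_dims (gradients first, then hessians).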

// Add statistics with the given bucket_id to StatsPartitionMap for all
// (instance, feature dim) pairs from (start_instance, start_feature_dim) to
// (end_instance, end_feature_dim): inclusive on the start and end instances,
// exclusive on both the start and end feature dims.
static void AddRangeStats(const int start_instance, const int end_instance,
                          const int start_feature_dim,
                          const int end_feature_dim,
                          StatsPartitionMap* stats_map,
                          const TTypes<float>::ConstMatrix& gradients,
                          const TTypes<float>::ConstMatrix& hessians,
                          const TTypes<int32>::ConstVec& node_ids,
                          const int32 feature_dims, const int32 bucket_id,
                          const int32 logits_dims, const int32 stats_dims) {
  DCHECK_LE(start_instance, end_instance);
  if (start_instance == end_instance) {
    DCHECK_LT(start_feature_dim, end_feature_dim);
  }
  for (int32 instance = start_instance; instance <= end_instance; ++instance) {
    const int32 start_f_dim =
        (instance == start_instance) ? start_feature_dim + 1 : 0;
    const int32 end_f_dim =
        (instance == end_instance) ? end_feature_dim : feature_dims;
    for (int32 f_dim = start_f_dim; f_dim < end_f_dim; ++f_dim) {
      AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
                            stats_map, gradients, hessians, node_ids);
    }
  }
}
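
// Illustrative example (not part of the original code): with feature_dims = 3,
// AddRangeStats(/*start_instance=*/0, /*end_instance=*/1,
//               /*start_feature_dim=*/0, /*end_feature_dim=*/2, ...)
// adds default-bucket stats for the (instance, f_dim) pairs (0, 1), (0, 2),
// (1, 0) and (1, 1): the start feature dim is skipped (exclusive start) and
// the end feature dim is not reached (exclusive end).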

class BoostedTreesSparseAggregateStatsOp : public OpKernel {
 public:
  explicit BoostedTreesSparseAggregateStatsOp(
      OpKernelConstruction* const context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
  }

  void Compute(OpKernelContext* const context) override {
    // node_ids.
    const Tensor* node_ids_t;
    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
    const auto node_ids = node_ids_t->vec<int32>();

    // gradients.
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    const auto gradients = gradients_t->matrix<float>();

    // hessians.
    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    const auto hessians = hessians_t->matrix<float>();

    // feature indices.
    const Tensor* feature_indices_t;
    OP_REQUIRES_OK(context,
                   context->input("feature_indices", &feature_indices_t));
    const auto feature_indices = feature_indices_t->matrix<int32>();

    // feature values.
    const Tensor* feature_values_t;
    OP_REQUIRES_OK(context,
                   context->input("feature_values", &feature_values_t));
    const auto feature_values = feature_values_t->vec<int32>();

    // feature shape.
    const Tensor* feature_shape_t;
    OP_REQUIRES_OK(context, context->input("feature_shape", &feature_shape_t));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(feature_shape_t->shape()),
                errors::InvalidArgument(
                    "Input shapes should be a vector but received shapes ",
                    feature_shape_t->shape().DebugString()));
    const auto feature_shape = feature_shape_t->vec<int32>();

    const int64 batch_size = gradients_t->dim_size(0);
    const int64 logits_dims = gradients_t->dim_size(1);
    const int64 hessians_dims = hessians_t->dim_size(1);
    const int64 stats_dims = logits_dims + hessians_dims;
    const int64 num_sparse_entries = feature_indices_t->dim_size(0);
    const int32 feature_dims = feature_shape(1);
    DCHECK_LE(num_sparse_entries, batch_size * feature_dims);

    // Aggregate statistics info to map.
    StatsPartitionMap stats_map;

    int prev_instance = 0;
    int prev_f_dim = -1;

    for (int i = 0; i < num_sparse_entries; ++i) {
      // The instance number within the batch.
      const int32 instance = feature_indices(i, 0);
      DCHECK_LE(instance, batch_size);
      DCHECK_GE(instance, prev_instance);
      // The node id within the tree.
      const int32 node_id = node_ids(instance);
      DCHECK_LE(node_id, max_splits_);
      // The feature dimension.
      const int32 f_dim = feature_indices(i, 1);
      DCHECK_LE(f_dim, feature_dims);
      // The bucket id of the value.
      const int32 bucket_id = feature_values(i);
      DCHECK_LE(bucket_id, num_buckets_);

      // Add statistics for the missing entries into the default bucket.
      // The last bucket is the default bucket.
      const int missing_entry_bucket = num_buckets_;
      AddRangeStats(prev_instance, instance, prev_f_dim, f_dim, &stats_map,
                    gradients, hessians, node_ids, feature_dims,
                    missing_entry_bucket, logits_dims, stats_dims);
      prev_instance = instance;
      prev_f_dim = f_dim;
      // Add statistics for the non-missing entry into
      // (cur_instance, cur_f_dim, bucket_id).
      AddInstanceStatsToMap(instance, f_dim, bucket_id, logits_dims, stats_dims,
                            &stats_map, gradients, hessians, node_ids);
    }
    AddRangeStats(prev_instance, batch_size - 1, prev_f_dim, feature_dims,
                  &stats_map, gradients, hessians, node_ids, feature_dims,
                  num_buckets_, logits_dims, stats_dims);

    // Serialize the statistics info map to tensor output.
    const int64 num_slots = stats_map.size() * stats_dims;
    Tensor* summary_indices_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stats_summary_indices",
                                            TensorShape({num_slots, 4}),
                                            &summary_indices_t));
    auto summary_indices = summary_indices_t->matrix<int32>();
    Tensor* summary_values_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output("stats_summary_values",
                                                     TensorShape({num_slots}),
                                                     &summary_values_t));
    auto summary_values = summary_values_t->vec<float>();
    int entry_index = 0;
    for (auto& iter : stats_map) {
      for (int stat_dim = 0; stat_dim < stats_dims; ++stat_dim) {
        summary_indices(entry_index, 0) = iter.first.node_id;
        summary_indices(entry_index, 1) = iter.first.feature_dim;
        summary_indices(entry_index, 2) = iter.first.bucket_id;
        summary_indices(entry_index, 3) = stat_dim;
        summary_values(entry_index) = iter.second[stat_dim];
        ++entry_index;
      }
    }
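
    // Illustrative example (not part of the original code): a map entry with
    // key {node_id=3, feature_dim=0, bucket_id=5} and stats {g, h} is emitted
    // as two rows of stats_summary_indices, [3, 0, 5, 0] and [3, 0, 5, 1],
    // with stats_summary_values g and h respectively.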

    Tensor* summary_shape_t = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output("stats_summary_shape",
                                          TensorShape({4}), &summary_shape_t));
    auto summary_shape = summary_shape_t->vec<int32>();
    summary_shape(0) = max_splits_;
    summary_shape(1) = feature_dims;
    summary_shape(2) = num_buckets_ + 1;
    summary_shape(3) = stats_dims;
  }

 private:
  int max_splits_;
  int num_buckets_;
};

REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesSparseAggregateStats").Device(DEVICE_CPU),
    BoostedTreesSparseAggregateStatsOp);

}  // namespace tensorflow