• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "third_party/eigen3/Eigen/Core"
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/tensor_shape.h"
19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
20 #include "tensorflow/core/kernels/boosted_trees/resources.h"
21 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
22 #include "tensorflow/core/lib/core/refcount.h"
23 
24 namespace tensorflow {
25 
namespace {
// Weight assigned to every tree when growing layer-by-layer. The learning
// rate is applied to the leaf weights instead of the tree weight (see the
// update ops below), so trees themselves carry a constant weight of 1.
constexpr float kLayerByLayerTreeWeight = 1.0;
// Bias centering (BoostedTreesCenterBiasOp) stops once the relative bias
// update |delta / bias| drops to or below this threshold.
constexpr float kMinDeltaForCenterBias = 0.01;

// Integer values mirror the "pruning_mode" attr / input tensor of the
// update ops below.
enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };

}  // namespace
33 
34 class BoostedTreesUpdateEnsembleOp : public OpKernel {
35  public:
BoostedTreesUpdateEnsembleOp(OpKernelConstruction * const context)36   explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context)
37       : OpKernel(context) {
38     OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
39 
40     int32 pruning_index;
41     OP_REQUIRES_OK(context, context->GetAttr("pruning_mode", &pruning_index));
42     pruning_mode_ = static_cast<PruningMode>(pruning_index);
43   }
44 
Compute(OpKernelContext * const context)45   void Compute(OpKernelContext* const context) override {
46     // Get decision tree ensemble.
47     core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
48     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
49                                            &ensemble_resource));
50     mutex_lock l(*ensemble_resource->get_mutex());
51     // Increase the ensemble stamp.
52     ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
53 
54     // Read node ids, gains, thresholds and node contribs.
55     OpInputList node_ids_list;
56     OpInputList gains_list;
57     OpInputList thresholds_list;
58     OpInputList left_node_contribs;
59     OpInputList right_node_contribs;
60     OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list));
61     OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
62     OP_REQUIRES_OK(context,
63                    context->input_list("thresholds", &thresholds_list));
64     OP_REQUIRES_OK(context, context->input_list("left_node_contribs",
65                                                 &left_node_contribs));
66     OP_REQUIRES_OK(context, context->input_list("right_node_contribs",
67                                                 &right_node_contribs));
68 
69     const Tensor* feature_ids_t;
70     OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
71     const auto feature_ids = feature_ids_t->vec<int32>();
72 
73     const Tensor* max_depth_t;
74     OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
75     const auto max_depth = max_depth_t->scalar<int32>()();
76 
77     const Tensor* learning_rate_t;
78     OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
79     const auto learning_rate = learning_rate_t->scalar<float>()();
80     // Op does not support multi-class, the V2 op below does however.
81     int32 logits_dimension = 1;
82     // Find best splits for each active node.
83     std::map<int32, boosted_trees::SplitCandidate> best_splits;
84     FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
85                           thresholds_list, left_node_contribs,
86                           right_node_contribs, feature_ids, &best_splits);
87 
88     int32 current_tree =
89         UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
90 
91     // No-op if no new splits can be considered.
92     if (best_splits.empty()) {
93       LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
94       return;
95     }
96 
97     const int32 new_num_layers =
98         ensemble_resource->GetNumLayersGrown(current_tree) + 1;
99     VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
100             << current_tree << " of ensemble of " << current_tree + 1
101             << " trees.";
102     bool split_happened = false;
103     int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
104     // Add the splits to the tree.
105     for (auto& split_entry : best_splits) {
106       const float gain = split_entry.second.gain;
107       if (pruning_mode_ == kPrePruning) {
108         // Don't consider negative splits if we're pre-pruning the tree.
109         // Note that zero-gain splits are acceptable.
110         if (gain < 0) {
111           continue;
112         }
113       }
114 
115       // unused.
116       int32 left_node_id;
117       int32 right_node_id;
118 
119       ensemble_resource->AddBucketizedSplitNode(current_tree, split_entry,
120                                                 logits_dimension, &left_node_id,
121                                                 &right_node_id);
122       split_happened = true;
123     }
124     int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
125     if (split_happened) {
126       // Update growable tree metadata.
127       ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
128       // Finalize the tree if needed.
129       if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) {
130         // If the tree is finalized, next growing will start from node 0;
131         node_id_start = 0;
132         node_id_end = 1;
133         ensemble_resource->SetIsFinalized(current_tree, true);
134         if (pruning_mode_ == kPostPruning) {
135           ensemble_resource->PostPruneTree(current_tree, logits_dimension);
136         }
137         if (ensemble_resource->num_trees() > 0) {
138           // Create a dummy new tree with an empty node.
139           ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, 1);
140         }
141       }
142       // If we managed to split, update the node range. If we didn't, don't
143       // update as we will try to split the same nodes with new instances.
144       ensemble_resource->UpdateLastLayerNodesRange(node_id_start, node_id_end);
145     }
146   }
147 
148  private:
UpdateGlobalAttemptsAndRetrieveGrowableTree(const core::RefCountPtr<BoostedTreesEnsembleResource> & resource)149   int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
150       const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) {
151     int32 num_trees = resource->num_trees();
152     int32 current_tree = num_trees - 1;
153 
154     // Increment global attempt stats.
155     resource->UpdateGrowingMetadata();
156 
157     // Note we don't set tree weight to be equal to learning rate, since we
158     // apply learning rate to leaf weights instead, when doing layer-by-layer
159     // boosting.
160     if (num_trees <= 0) {
161       // Create a new tree with a no-op leaf.
162       current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, 1);
163     }
164     return current_tree;
165   }
166 
167   // Helper method which effectively does a reduce over all split candidates
168   // and finds the best split for each node.
FindBestSplitsPerNode(OpKernelContext * const context,const float learning_rate,const OpInputList & node_ids_list,const OpInputList & gains_list,const OpInputList & thresholds_list,const OpInputList & left_node_contribs_list,const OpInputList & right_node_contribs_list,const TTypes<const int32>::Vec & feature_ids,std::map<int32,boosted_trees::SplitCandidate> * best_split_per_node)169   void FindBestSplitsPerNode(
170       OpKernelContext* const context, const float learning_rate,
171       const OpInputList& node_ids_list, const OpInputList& gains_list,
172       const OpInputList& thresholds_list,
173       const OpInputList& left_node_contribs_list,
174       const OpInputList& right_node_contribs_list,
175       const TTypes<const int32>::Vec& feature_ids,
176       std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
177     // Find best split per node going through every feature candidate.
178     for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
179       const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
180       const auto& gains = gains_list[feature_idx].vec<float>();
181       const auto& thresholds = thresholds_list[feature_idx].vec<int32>();
182       const auto& left_node_contribs =
183           left_node_contribs_list[feature_idx].matrix<float>();
184       const auto& right_node_contribs =
185           right_node_contribs_list[feature_idx].matrix<float>();
186 
187       for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
188            ++candidate_idx) {
189         // Get current split candidate.
190         const auto& node_id = node_ids(candidate_idx);
191         const auto& gain = gains(candidate_idx);
192         const auto& best_split_it = best_split_per_node->find(node_id);
193         boosted_trees::SplitCandidate candidate;
194         candidate.feature_id = feature_ids(feature_idx);
195         candidate.candidate_idx = candidate_idx;
196         candidate.gain = gain;
197         candidate.dimension_id = 0;
198         candidate.threshold = thresholds(candidate_idx);
199         candidate.left_node_contribs.push_back(
200             learning_rate * left_node_contribs(candidate_idx, 0));
201         candidate.right_node_contribs.push_back(
202             learning_rate * right_node_contribs(candidate_idx, 0));
203         candidate.split_type = boosted_trees::SplitTypeWithDefault_Name(
204             boosted_trees::INEQUALITY_DEFAULT_LEFT);
205 
206         if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
207                              GainsAreEqual(gain, best_split_it->second.gain))) {
208           const auto best_candidate = (*best_split_per_node)[node_id];
209           const int32 best_feature_id = best_candidate.feature_id;
210           const int32 feature_id = candidate.feature_id;
211           VLOG(2) << "Breaking ties on feature ids and buckets";
212           // Breaking ties deterministically.
213           if (feature_id < best_feature_id) {
214             (*best_split_per_node)[node_id] = candidate;
215           }
216         } else if (best_split_it == best_split_per_node->end() ||
217                    GainIsLarger(gain, best_split_it->second.gain)) {
218           (*best_split_per_node)[node_id] = candidate;
219         }
220       }
221     }
222   }
223 
224  private:
225   int32 num_features_;
226   PruningMode pruning_mode_;
227 };
228 
229 REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
230                         BoostedTreesUpdateEnsembleOp);
231 
232 // V2 of UpdateEnsembleOp that takes in split type and feature dimension id.
233 class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
234  public:
BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction * const context)235   explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context)
236       : OpKernel(context) {
237     OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
238     OP_REQUIRES_OK(context, context->GetAttr("num_groups", &num_groups_));
239   }
240 
Compute(OpKernelContext * const context)241   void Compute(OpKernelContext* const context) override {
242     // Get decision tree ensemble.
243     core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
244     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
245                                            &ensemble_resource));
246     mutex_lock l(*ensemble_resource->get_mutex());
247     // Increase the ensemble stamp.
248     ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
249 
250     // Read node ids, gains, thresholds and node contribs.
251     OpInputList node_ids_list;
252     OpInputList gains_list;
253     OpInputList thresholds_list;
254     OpInputList dimension_ids_list;
255     OpInputList left_node_contribs_list;
256     OpInputList right_node_contribs_list;
257     OpInputList split_types_list;
258     OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list));
259     OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
260     OP_REQUIRES_OK(context,
261                    context->input_list("thresholds", &thresholds_list));
262     OP_REQUIRES_OK(context,
263                    context->input_list("dimension_ids", &dimension_ids_list));
264     OP_REQUIRES_OK(context, context->input_list("left_node_contribs",
265                                                 &left_node_contribs_list));
266     OP_REQUIRES_OK(context, context->input_list("right_node_contribs",
267                                                 &right_node_contribs_list));
268     OP_REQUIRES_OK(context,
269                    context->input_list("split_types", &split_types_list));
270 
271     OpInputList feature_ids_list;
272     OP_REQUIRES_OK(context,
273                    context->input_list("feature_ids", &feature_ids_list));
274 
275     const Tensor* max_depth_t;
276     OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
277     const auto max_depth = max_depth_t->scalar<int32>()();
278 
279     const Tensor* learning_rate_t;
280     OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
281     const auto learning_rate = learning_rate_t->scalar<float>()();
282 
283     const Tensor* pruning_mode_t;
284     OP_REQUIRES_OK(context, context->input("pruning_mode", &pruning_mode_t));
285     const auto pruning_mode =
286         static_cast<PruningMode>(pruning_mode_t->scalar<int32>()());
287     // Find best splits for each active node.
288     std::map<int32, boosted_trees::SplitCandidate> best_splits;
289     FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
290                           thresholds_list, dimension_ids_list,
291                           left_node_contribs_list, right_node_contribs_list,
292                           split_types_list, feature_ids_list, &best_splits);
293 
294     int32 current_tree =
295         UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
296 
297     // No-op if no new splits can be considered.
298     if (best_splits.empty()) {
299       LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
300       return;
301     }
302 
303     const int32 new_num_layers =
304         ensemble_resource->GetNumLayersGrown(current_tree) + 1;
305     VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
306             << current_tree << " of ensemble of " << current_tree + 1
307             << " trees.";
308     bool split_happened = false;
309     int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
310     // Add the splits to the tree.
311     for (auto& split_entry : best_splits) {
312       const float gain = split_entry.second.gain;
313       const string split_type = split_entry.second.split_type;
314 
315       if (pruning_mode == kPrePruning) {
316         // Don't consider negative splits if we're pre-pruning the tree.
317         // Note that zero-gain splits are acceptable.
318         if (gain < 0) {
319           continue;
320         }
321       }
322 
323       // unused.
324       int32 left_node_id;
325       int32 right_node_id;
326 
327       boosted_trees::SplitTypeWithDefault split_type_with_default;
328       bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
329           split_type, &split_type_with_default);
330       DCHECK(parsed);
331       if (split_type_with_default == boosted_trees::EQUALITY_DEFAULT_RIGHT) {
332         // Add equality split to the node.
333         ensemble_resource->AddCategoricalSplitNode(current_tree, split_entry,
334                                                    logits_dim_, &left_node_id,
335                                                    &right_node_id);
336       } else {
337         // Add inequality split to the node.
338         ensemble_resource->AddBucketizedSplitNode(current_tree, split_entry,
339                                                   logits_dim_, &left_node_id,
340                                                   &right_node_id);
341       }
342       split_happened = true;
343     }
344     int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
345     if (split_happened) {
346       // Update growable tree metadata.
347       ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
348       // Finalize the tree if needed.
349       if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) {
350         // If the tree is finalized, next growing will start from node 0;
351         node_id_start = 0;
352         node_id_end = 1;
353         ensemble_resource->SetIsFinalized(current_tree, true);
354         if (pruning_mode == kPostPruning) {
355           ensemble_resource->PostPruneTree(current_tree, logits_dim_);
356         }
357         if (ensemble_resource->num_trees() > 0) {
358           // Create a dummy new tree with an empty node.
359           ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
360         }
361       }
362       // If we managed to split, update the node range. If we didn't, don't
363       // update as we will try to split the same nodes with new instances.
364       ensemble_resource->UpdateLastLayerNodesRange(node_id_start, node_id_end);
365     }
366   }
367 
368  private:
UpdateGlobalAttemptsAndRetrieveGrowableTree(const core::RefCountPtr<BoostedTreesEnsembleResource> & resource)369   int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
370       const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) {
371     int32 num_trees = resource->num_trees();
372     int32 current_tree = num_trees - 1;
373 
374     // Increment global attempt stats.
375     resource->UpdateGrowingMetadata();
376 
377     // Note we don't set tree weight to be equal to learning rate, since we
378     // apply learning rate to leaf weights instead, when doing layer-by-layer
379     // boosting.
380     if (num_trees <= 0) {
381       // Create a new tree with a no-op leaf.
382       current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
383     }
384     return current_tree;
385   }
386 
387   // Helper method which effectively does a reduce over all split candidates
388   // and finds the best split for each node.
FindBestSplitsPerNode(OpKernelContext * const context,const float learning_rate,const OpInputList & node_ids_list,const OpInputList & gains_list,const OpInputList & thresholds_list,const OpInputList & dimension_ids_list,const OpInputList & left_node_contribs_list,const OpInputList & right_node_contribs_list,const OpInputList & split_types_list,const OpInputList & feature_ids_list,std::map<int32,boosted_trees::SplitCandidate> * best_split_per_node)389   void FindBestSplitsPerNode(
390       OpKernelContext* const context, const float learning_rate,
391       const OpInputList& node_ids_list, const OpInputList& gains_list,
392       const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
393       const OpInputList& left_node_contribs_list,
394       const OpInputList& right_node_contribs_list,
395       const OpInputList& split_types_list, const OpInputList& feature_ids_list,
396       std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
397     // Find best split per node going through every feature candidate.
398     for (int64 group_idx = 0; group_idx < num_groups_; ++group_idx) {
399       const auto& node_ids = node_ids_list[group_idx].vec<int32>();
400       const auto& gains = gains_list[group_idx].vec<float>();
401       const auto& feature_ids = feature_ids_list[group_idx].vec<int32>();
402       const auto& thresholds = thresholds_list[group_idx].vec<int32>();
403       const auto& dimension_ids = dimension_ids_list[group_idx].vec<int32>();
404       const auto& left_node_contribs =
405           left_node_contribs_list[group_idx].matrix<float>();
406       const auto& right_node_contribs =
407           right_node_contribs_list[group_idx].matrix<float>();
408       const auto& split_types = split_types_list[group_idx].vec<tstring>();
409 
410       for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
411            ++candidate_idx) {
412         // Get current split candidate.
413         const auto& node_id = node_ids(candidate_idx);
414         const auto& gain = gains(candidate_idx);
415         const auto& feature_id = feature_ids(candidate_idx);
416 
417         auto best_split_it = best_split_per_node->find(node_id);
418         boosted_trees::SplitCandidate candidate;
419         candidate.candidate_idx = candidate_idx;
420         candidate.gain = gain;
421         candidate.feature_id = feature_id;
422         candidate.threshold = thresholds(candidate_idx);
423         candidate.dimension_id = dimension_ids(candidate_idx);
424         candidate.split_type = split_types(candidate_idx);
425         for (int i = 0; i < logits_dim_; ++i) {
426           candidate.left_node_contribs.push_back(
427               learning_rate * left_node_contribs(candidate_idx, i));
428           candidate.right_node_contribs.push_back(
429               learning_rate * right_node_contribs(candidate_idx, i));
430         }
431         if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
432                              GainsAreEqual(gain, best_split_it->second.gain))) {
433           const auto& best_candidate = (*best_split_per_node)[node_id];
434           const int32 best_feature_id = best_candidate.feature_id;
435           const int32 feature_id = candidate.feature_id;
436           VLOG(2) << "Breaking ties on feature ids and buckets";
437           // Breaking ties deterministically.
438           if (feature_id < best_feature_id) {
439             (*best_split_per_node)[node_id] = candidate;
440           }
441         } else if (best_split_it == best_split_per_node->end() ||
442                    GainIsLarger(gain, best_split_it->second.gain)) {
443           (*best_split_per_node)[node_id] = candidate;
444         }
445       }
446     }
447   }
448 
449  private:
450   int32 logits_dim_;
451   int32 num_groups_;
452 };
453 
454 REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU),
455                         BoostedTreesUpdateEnsembleV2Op);
456 
457 class BoostedTreesCenterBiasOp : public OpKernel {
458  public:
BoostedTreesCenterBiasOp(OpKernelConstruction * const context)459   explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context)
460       : OpKernel(context) {}
461 
Compute(OpKernelContext * const context)462   void Compute(OpKernelContext* const context) override {
463     // Get decision tree ensemble.
464     core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource;
465     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
466                                            &ensemble_resource));
467     mutex_lock l(*ensemble_resource->get_mutex());
468     // Increase the ensemble stamp.
469     ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
470 
471     // Read means of hessians and gradients
472     const Tensor* mean_gradients_t;
473     OP_REQUIRES_OK(context,
474                    context->input("mean_gradients", &mean_gradients_t));
475     const int32 logits_dim = mean_gradients_t->dim_size(1);
476     const Tensor* mean_hessians_t;
477     OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t));
478 
479     // Get the regularization options.
480     const Tensor* l1_t;
481     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
482     const auto l1 = l1_t->scalar<float>()();
483     const Tensor* l2_t;
484     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
485     const auto l2 = l2_t->scalar<float>()();
486 
487     // For now, assume 1-dimensional weight on leaves.
488     Eigen::VectorXf logits_vector(1);
489     float unused_gain;
490 
491     // TODO(crawles): Support multiclass.
492     DCHECK_EQ(logits_dim, 1);
493     Eigen::VectorXf gradients_mean(1);
494     Eigen::VectorXf hessians_mean(1);
495     gradients_mean[0] = mean_gradients_t->flat<float>()(0);
496     hessians_mean[0] = mean_hessians_t->flat<float>()(0);
497     CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2,
498                              &logits_vector, &unused_gain);
499     const float logits = logits_vector[0];
500 
501     float current_bias = 0.0;
502     bool continue_centering = true;
503     if (ensemble_resource->num_trees() == 0) {
504       ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, {logits},
505                                               1);
506       current_bias = logits;
507     } else {
508       const auto& current_biases = ensemble_resource->node_value(0, 0);
509       DCHECK_EQ(current_biases.size(), 1);
510       current_bias = current_biases[0];
511       continue_centering =
512           std::abs(logits / current_bias) > kMinDeltaForCenterBias;
513       current_bias += logits;
514       ensemble_resource->set_node_value(0, 0, current_bias);
515     }
516 
517     Tensor* continue_centering_t = nullptr;
518     OP_REQUIRES_OK(
519         context, context->allocate_output("continue_centering", TensorShape({}),
520                                           &continue_centering_t));
521     // Check if we need to continue centering bias.
522     continue_centering_t->scalar<bool>()() = continue_centering;
523   }
524 };
525 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU),
526                         BoostedTreesCenterBiasOp);
527 
528 }  // namespace tensorflow
529