1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "third_party/eigen3/Eigen/Core" 17 #include "tensorflow/core/framework/op_kernel.h" 18 #include "tensorflow/core/framework/tensor_shape.h" 19 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" 20 #include "tensorflow/core/kernels/boosted_trees/resources.h" 21 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" 22 #include "tensorflow/core/lib/core/refcount.h" 23 24 namespace tensorflow { 25 26 namespace { 27 constexpr float kLayerByLayerTreeWeight = 1.0; 28 constexpr float kMinDeltaForCenterBias = 0.01; 29 30 enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 }; 31 32 } // namespace 33 34 class BoostedTreesUpdateEnsembleOp : public OpKernel { 35 public: BoostedTreesUpdateEnsembleOp(OpKernelConstruction * const context)36 explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context) 37 : OpKernel(context) { 38 OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_)); 39 40 int32 pruning_index; 41 OP_REQUIRES_OK(context, context->GetAttr("pruning_mode", &pruning_index)); 42 pruning_mode_ = static_cast<PruningMode>(pruning_index); 43 } 44 Compute(OpKernelContext * const context)45 void Compute(OpKernelContext* const context) override { 46 // Get decision tree ensemble. 47 core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource; 48 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 49 &ensemble_resource)); 50 mutex_lock l(*ensemble_resource->get_mutex()); 51 // Increase the ensemble stamp. 52 ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); 53 54 // Read node ids, gains, thresholds and node contribs. 55 OpInputList node_ids_list; 56 OpInputList gains_list; 57 OpInputList thresholds_list; 58 OpInputList left_node_contribs; 59 OpInputList right_node_contribs; 60 OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list)); 61 OP_REQUIRES_OK(context, context->input_list("gains", &gains_list)); 62 OP_REQUIRES_OK(context, 63 context->input_list("thresholds", &thresholds_list)); 64 OP_REQUIRES_OK(context, context->input_list("left_node_contribs", 65 &left_node_contribs)); 66 OP_REQUIRES_OK(context, context->input_list("right_node_contribs", 67 &right_node_contribs)); 68 69 const Tensor* feature_ids_t; 70 OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); 71 const auto feature_ids = feature_ids_t->vec<int32>(); 72 73 const Tensor* max_depth_t; 74 OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t)); 75 const auto max_depth = max_depth_t->scalar<int32>()(); 76 77 const Tensor* learning_rate_t; 78 OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t)); 79 const auto learning_rate = learning_rate_t->scalar<float>()(); 80 // Op does not support multi-class, the V2 op below does however. 81 int32 logits_dimension = 1; 82 // Find best splits for each active node. 83 std::map<int32, boosted_trees::SplitCandidate> best_splits; 84 FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list, 85 thresholds_list, left_node_contribs, 86 right_node_contribs, feature_ids, &best_splits); 87 88 int32 current_tree = 89 UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource); 90 91 // No-op if no new splits can be considered. 92 if (best_splits.empty()) { 93 LOG(WARNING) << "Not growing tree ensemble as no good splits were found."; 94 return; 95 } 96 97 const int32 new_num_layers = 98 ensemble_resource->GetNumLayersGrown(current_tree) + 1; 99 VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #" 100 << current_tree << " of ensemble of " << current_tree + 1 101 << " trees."; 102 bool split_happened = false; 103 int32 node_id_start = ensemble_resource->GetNumNodes(current_tree); 104 // Add the splits to the tree. 105 for (auto& split_entry : best_splits) { 106 const float gain = split_entry.second.gain; 107 if (pruning_mode_ == kPrePruning) { 108 // Don't consider negative splits if we're pre-pruning the tree. 109 // Note that zero-gain splits are acceptable. 110 if (gain < 0) { 111 continue; 112 } 113 } 114 115 // unused. 116 int32 left_node_id; 117 int32 right_node_id; 118 119 ensemble_resource->AddBucketizedSplitNode(current_tree, split_entry, 120 logits_dimension, &left_node_id, 121 &right_node_id); 122 split_happened = true; 123 } 124 int32 node_id_end = ensemble_resource->GetNumNodes(current_tree); 125 if (split_happened) { 126 // Update growable tree metadata. 127 ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers); 128 // Finalize the tree if needed. 129 if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) { 130 // If the tree is finalized, next growing will start from node 0; 131 node_id_start = 0; 132 node_id_end = 1; 133 ensemble_resource->SetIsFinalized(current_tree, true); 134 if (pruning_mode_ == kPostPruning) { 135 ensemble_resource->PostPruneTree(current_tree, logits_dimension); 136 } 137 if (ensemble_resource->num_trees() > 0) { 138 // Create a dummy new tree with an empty node. 139 ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, 1); 140 } 141 } 142 // If we managed to split, update the node range. If we didn't, don't 143 // update as we will try to split the same nodes with new instances. 144 ensemble_resource->UpdateLastLayerNodesRange(node_id_start, node_id_end); 145 } 146 } 147 148 private: UpdateGlobalAttemptsAndRetrieveGrowableTree(const core::RefCountPtr<BoostedTreesEnsembleResource> & resource)149 int32 UpdateGlobalAttemptsAndRetrieveGrowableTree( 150 const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) { 151 int32 num_trees = resource->num_trees(); 152 int32 current_tree = num_trees - 1; 153 154 // Increment global attempt stats. 155 resource->UpdateGrowingMetadata(); 156 157 // Note we don't set tree weight to be equal to learning rate, since we 158 // apply learning rate to leaf weights instead, when doing layer-by-layer 159 // boosting. 160 if (num_trees <= 0) { 161 // Create a new tree with a no-op leaf. 162 current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, 1); 163 } 164 return current_tree; 165 } 166 167 // Helper method which effectively does a reduce over all split candidates 168 // and finds the best split for each node. FindBestSplitsPerNode(OpKernelContext * const context,const float learning_rate,const OpInputList & node_ids_list,const OpInputList & gains_list,const OpInputList & thresholds_list,const OpInputList & left_node_contribs_list,const OpInputList & right_node_contribs_list,const TTypes<const int32>::Vec & feature_ids,std::map<int32,boosted_trees::SplitCandidate> * best_split_per_node)169 void FindBestSplitsPerNode( 170 OpKernelContext* const context, const float learning_rate, 171 const OpInputList& node_ids_list, const OpInputList& gains_list, 172 const OpInputList& thresholds_list, 173 const OpInputList& left_node_contribs_list, 174 const OpInputList& right_node_contribs_list, 175 const TTypes<const int32>::Vec& feature_ids, 176 std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) { 177 // Find best split per node going through every feature candidate. 178 for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) { 179 const auto& node_ids = node_ids_list[feature_idx].vec<int32>(); 180 const auto& gains = gains_list[feature_idx].vec<float>(); 181 const auto& thresholds = thresholds_list[feature_idx].vec<int32>(); 182 const auto& left_node_contribs = 183 left_node_contribs_list[feature_idx].matrix<float>(); 184 const auto& right_node_contribs = 185 right_node_contribs_list[feature_idx].matrix<float>(); 186 187 for (size_t candidate_idx = 0; candidate_idx < node_ids.size(); 188 ++candidate_idx) { 189 // Get current split candidate. 190 const auto& node_id = node_ids(candidate_idx); 191 const auto& gain = gains(candidate_idx); 192 const auto& best_split_it = best_split_per_node->find(node_id); 193 boosted_trees::SplitCandidate candidate; 194 candidate.feature_id = feature_ids(feature_idx); 195 candidate.candidate_idx = candidate_idx; 196 candidate.gain = gain; 197 candidate.dimension_id = 0; 198 candidate.threshold = thresholds(candidate_idx); 199 candidate.left_node_contribs.push_back( 200 learning_rate * left_node_contribs(candidate_idx, 0)); 201 candidate.right_node_contribs.push_back( 202 learning_rate * right_node_contribs(candidate_idx, 0)); 203 candidate.split_type = boosted_trees::SplitTypeWithDefault_Name( 204 boosted_trees::INEQUALITY_DEFAULT_LEFT); 205 206 if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && 207 GainsAreEqual(gain, best_split_it->second.gain))) { 208 const auto best_candidate = (*best_split_per_node)[node_id]; 209 const int32 best_feature_id = best_candidate.feature_id; 210 const int32 feature_id = candidate.feature_id; 211 VLOG(2) << "Breaking ties on feature ids and buckets"; 212 // Breaking ties deterministically. 213 if (feature_id < best_feature_id) { 214 (*best_split_per_node)[node_id] = candidate; 215 } 216 } else if (best_split_it == best_split_per_node->end() || 217 GainIsLarger(gain, best_split_it->second.gain)) { 218 (*best_split_per_node)[node_id] = candidate; 219 } 220 } 221 } 222 } 223 224 private: 225 int32 num_features_; 226 PruningMode pruning_mode_; 227 }; 228 229 REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU), 230 BoostedTreesUpdateEnsembleOp); 231 232 // V2 of UpdateEnsembleOp that takes in split type and feature dimension id. 233 class BoostedTreesUpdateEnsembleV2Op : public OpKernel { 234 public: BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction * const context)235 explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context) 236 : OpKernel(context) { 237 OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_)); 238 OP_REQUIRES_OK(context, context->GetAttr("num_groups", &num_groups_)); 239 } 240 Compute(OpKernelContext * const context)241 void Compute(OpKernelContext* const context) override { 242 // Get decision tree ensemble. 243 core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource; 244 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 245 &ensemble_resource)); 246 mutex_lock l(*ensemble_resource->get_mutex()); 247 // Increase the ensemble stamp. 248 ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); 249 250 // Read node ids, gains, thresholds and node contribs. 251 OpInputList node_ids_list; 252 OpInputList gains_list; 253 OpInputList thresholds_list; 254 OpInputList dimension_ids_list; 255 OpInputList left_node_contribs_list; 256 OpInputList right_node_contribs_list; 257 OpInputList split_types_list; 258 OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list)); 259 OP_REQUIRES_OK(context, context->input_list("gains", &gains_list)); 260 OP_REQUIRES_OK(context, 261 context->input_list("thresholds", &thresholds_list)); 262 OP_REQUIRES_OK(context, 263 context->input_list("dimension_ids", &dimension_ids_list)); 264 OP_REQUIRES_OK(context, context->input_list("left_node_contribs", 265 &left_node_contribs_list)); 266 OP_REQUIRES_OK(context, context->input_list("right_node_contribs", 267 &right_node_contribs_list)); 268 OP_REQUIRES_OK(context, 269 context->input_list("split_types", &split_types_list)); 270 271 OpInputList feature_ids_list; 272 OP_REQUIRES_OK(context, 273 context->input_list("feature_ids", &feature_ids_list)); 274 275 const Tensor* max_depth_t; 276 OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t)); 277 const auto max_depth = max_depth_t->scalar<int32>()(); 278 279 const Tensor* learning_rate_t; 280 OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t)); 281 const auto learning_rate = learning_rate_t->scalar<float>()(); 282 283 const Tensor* pruning_mode_t; 284 OP_REQUIRES_OK(context, context->input("pruning_mode", &pruning_mode_t)); 285 const auto pruning_mode = 286 static_cast<PruningMode>(pruning_mode_t->scalar<int32>()()); 287 // Find best splits for each active node. 288 std::map<int32, boosted_trees::SplitCandidate> best_splits; 289 FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list, 290 thresholds_list, dimension_ids_list, 291 left_node_contribs_list, right_node_contribs_list, 292 split_types_list, feature_ids_list, &best_splits); 293 294 int32 current_tree = 295 UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource); 296 297 // No-op if no new splits can be considered. 298 if (best_splits.empty()) { 299 LOG(WARNING) << "Not growing tree ensemble as no good splits were found."; 300 return; 301 } 302 303 const int32 new_num_layers = 304 ensemble_resource->GetNumLayersGrown(current_tree) + 1; 305 VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #" 306 << current_tree << " of ensemble of " << current_tree + 1 307 << " trees."; 308 bool split_happened = false; 309 int32 node_id_start = ensemble_resource->GetNumNodes(current_tree); 310 // Add the splits to the tree. 311 for (auto& split_entry : best_splits) { 312 const float gain = split_entry.second.gain; 313 const string split_type = split_entry.second.split_type; 314 315 if (pruning_mode == kPrePruning) { 316 // Don't consider negative splits if we're pre-pruning the tree. 317 // Note that zero-gain splits are acceptable. 318 if (gain < 0) { 319 continue; 320 } 321 } 322 323 // unused. 324 int32 left_node_id; 325 int32 right_node_id; 326 327 boosted_trees::SplitTypeWithDefault split_type_with_default; 328 bool parsed = boosted_trees::SplitTypeWithDefault_Parse( 329 split_type, &split_type_with_default); 330 DCHECK(parsed); 331 if (split_type_with_default == boosted_trees::EQUALITY_DEFAULT_RIGHT) { 332 // Add equality split to the node. 333 ensemble_resource->AddCategoricalSplitNode(current_tree, split_entry, 334 logits_dim_, &left_node_id, 335 &right_node_id); 336 } else { 337 // Add inequality split to the node. 338 ensemble_resource->AddBucketizedSplitNode(current_tree, split_entry, 339 logits_dim_, &left_node_id, 340 &right_node_id); 341 } 342 split_happened = true; 343 } 344 int32 node_id_end = ensemble_resource->GetNumNodes(current_tree); 345 if (split_happened) { 346 // Update growable tree metadata. 347 ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers); 348 // Finalize the tree if needed. 349 if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth) { 350 // If the tree is finalized, next growing will start from node 0; 351 node_id_start = 0; 352 node_id_end = 1; 353 ensemble_resource->SetIsFinalized(current_tree, true); 354 if (pruning_mode == kPostPruning) { 355 ensemble_resource->PostPruneTree(current_tree, logits_dim_); 356 } 357 if (ensemble_resource->num_trees() > 0) { 358 // Create a dummy new tree with an empty node. 359 ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_); 360 } 361 } 362 // If we managed to split, update the node range. If we didn't, don't 363 // update as we will try to split the same nodes with new instances. 364 ensemble_resource->UpdateLastLayerNodesRange(node_id_start, node_id_end); 365 } 366 } 367 368 private: UpdateGlobalAttemptsAndRetrieveGrowableTree(const core::RefCountPtr<BoostedTreesEnsembleResource> & resource)369 int32 UpdateGlobalAttemptsAndRetrieveGrowableTree( 370 const core::RefCountPtr<BoostedTreesEnsembleResource>& resource) { 371 int32 num_trees = resource->num_trees(); 372 int32 current_tree = num_trees - 1; 373 374 // Increment global attempt stats. 375 resource->UpdateGrowingMetadata(); 376 377 // Note we don't set tree weight to be equal to learning rate, since we 378 // apply learning rate to leaf weights instead, when doing layer-by-layer 379 // boosting. 380 if (num_trees <= 0) { 381 // Create a new tree with a no-op leaf. 382 current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_); 383 } 384 return current_tree; 385 } 386 387 // Helper method which effectively does a reduce over all split candidates 388 // and finds the best split for each node. FindBestSplitsPerNode(OpKernelContext * const context,const float learning_rate,const OpInputList & node_ids_list,const OpInputList & gains_list,const OpInputList & thresholds_list,const OpInputList & dimension_ids_list,const OpInputList & left_node_contribs_list,const OpInputList & right_node_contribs_list,const OpInputList & split_types_list,const OpInputList & feature_ids_list,std::map<int32,boosted_trees::SplitCandidate> * best_split_per_node)389 void FindBestSplitsPerNode( 390 OpKernelContext* const context, const float learning_rate, 391 const OpInputList& node_ids_list, const OpInputList& gains_list, 392 const OpInputList& thresholds_list, const OpInputList& dimension_ids_list, 393 const OpInputList& left_node_contribs_list, 394 const OpInputList& right_node_contribs_list, 395 const OpInputList& split_types_list, const OpInputList& feature_ids_list, 396 std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) { 397 // Find best split per node going through every feature candidate. 398 for (int64 group_idx = 0; group_idx < num_groups_; ++group_idx) { 399 const auto& node_ids = node_ids_list[group_idx].vec<int32>(); 400 const auto& gains = gains_list[group_idx].vec<float>(); 401 const auto& feature_ids = feature_ids_list[group_idx].vec<int32>(); 402 const auto& thresholds = thresholds_list[group_idx].vec<int32>(); 403 const auto& dimension_ids = dimension_ids_list[group_idx].vec<int32>(); 404 const auto& left_node_contribs = 405 left_node_contribs_list[group_idx].matrix<float>(); 406 const auto& right_node_contribs = 407 right_node_contribs_list[group_idx].matrix<float>(); 408 const auto& split_types = split_types_list[group_idx].vec<tstring>(); 409 410 for (size_t candidate_idx = 0; candidate_idx < node_ids.size(); 411 ++candidate_idx) { 412 // Get current split candidate. 413 const auto& node_id = node_ids(candidate_idx); 414 const auto& gain = gains(candidate_idx); 415 const auto& feature_id = feature_ids(candidate_idx); 416 417 auto best_split_it = best_split_per_node->find(node_id); 418 boosted_trees::SplitCandidate candidate; 419 candidate.candidate_idx = candidate_idx; 420 candidate.gain = gain; 421 candidate.feature_id = feature_id; 422 candidate.threshold = thresholds(candidate_idx); 423 candidate.dimension_id = dimension_ids(candidate_idx); 424 candidate.split_type = split_types(candidate_idx); 425 for (int i = 0; i < logits_dim_; ++i) { 426 candidate.left_node_contribs.push_back( 427 learning_rate * left_node_contribs(candidate_idx, i)); 428 candidate.right_node_contribs.push_back( 429 learning_rate * right_node_contribs(candidate_idx, i)); 430 } 431 if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && 432 GainsAreEqual(gain, best_split_it->second.gain))) { 433 const auto& best_candidate = (*best_split_per_node)[node_id]; 434 const int32 best_feature_id = best_candidate.feature_id; 435 const int32 feature_id = candidate.feature_id; 436 VLOG(2) << "Breaking ties on feature ids and buckets"; 437 // Breaking ties deterministically. 438 if (feature_id < best_feature_id) { 439 (*best_split_per_node)[node_id] = candidate; 440 } 441 } else if (best_split_it == best_split_per_node->end() || 442 GainIsLarger(gain, best_split_it->second.gain)) { 443 (*best_split_per_node)[node_id] = candidate; 444 } 445 } 446 } 447 } 448 449 private: 450 int32 logits_dim_; 451 int32 num_groups_; 452 }; 453 454 REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU), 455 BoostedTreesUpdateEnsembleV2Op); 456 457 class BoostedTreesCenterBiasOp : public OpKernel { 458 public: BoostedTreesCenterBiasOp(OpKernelConstruction * const context)459 explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context) 460 : OpKernel(context) {} 461 Compute(OpKernelContext * const context)462 void Compute(OpKernelContext* const context) override { 463 // Get decision tree ensemble. 464 core::RefCountPtr<BoostedTreesEnsembleResource> ensemble_resource; 465 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 466 &ensemble_resource)); 467 mutex_lock l(*ensemble_resource->get_mutex()); 468 // Increase the ensemble stamp. 469 ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); 470 471 // Read means of hessians and gradients 472 const Tensor* mean_gradients_t; 473 OP_REQUIRES_OK(context, 474 context->input("mean_gradients", &mean_gradients_t)); 475 const int32 logits_dim = mean_gradients_t->dim_size(1); 476 const Tensor* mean_hessians_t; 477 OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t)); 478 479 // Get the regularization options. 480 const Tensor* l1_t; 481 OP_REQUIRES_OK(context, context->input("l1", &l1_t)); 482 const auto l1 = l1_t->scalar<float>()(); 483 const Tensor* l2_t; 484 OP_REQUIRES_OK(context, context->input("l2", &l2_t)); 485 const auto l2 = l2_t->scalar<float>()(); 486 487 // For now, assume 1-dimensional weight on leaves. 488 Eigen::VectorXf logits_vector(1); 489 float unused_gain; 490 491 // TODO(crawles): Support multiclass. 492 DCHECK_EQ(logits_dim, 1); 493 Eigen::VectorXf gradients_mean(1); 494 Eigen::VectorXf hessians_mean(1); 495 gradients_mean[0] = mean_gradients_t->flat<float>()(0); 496 hessians_mean[0] = mean_hessians_t->flat<float>()(0); 497 CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, 498 &logits_vector, &unused_gain); 499 const float logits = logits_vector[0]; 500 501 float current_bias = 0.0; 502 bool continue_centering = true; 503 if (ensemble_resource->num_trees() == 0) { 504 ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, {logits}, 505 1); 506 current_bias = logits; 507 } else { 508 const auto& current_biases = ensemble_resource->node_value(0, 0); 509 DCHECK_EQ(current_biases.size(), 1); 510 current_bias = current_biases[0]; 511 continue_centering = 512 std::abs(logits / current_bias) > kMinDeltaForCenterBias; 513 current_bias += logits; 514 ensemble_resource->set_node_value(0, 0, current_bias); 515 } 516 517 Tensor* continue_centering_t = nullptr; 518 OP_REQUIRES_OK( 519 context, context->allocate_output("continue_centering", TensorShape({}), 520 &continue_centering_t)); 521 // Check if we need to continue centering bias. 522 continue_centering_t->scalar<bool>()() = continue_centering; 523 } 524 }; 525 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU), 526 BoostedTreesCenterBiasOp); 527 528 } // namespace tensorflow 529