# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of handler for split nodes for float columns.

The general idea in batch split finding is that each handler accumulates its
own statistics on multiple workers. After some number of steps, the master runs
the make_splits() sub-graph of each handler, and each handler returns its best
split per partition.

We ensure consistency of statistics by using stamp tokens for read and write
operations. During each update of the model, a new stamp token is created. This
stamp token makes sure that updates from previous iterations are not included
in the statistics for this iteration.

Inequality splits for float features are created similarly to the approximate
algorithm described in https://arxiv.org/pdf/1603.02754v3.pdf. Weighted
quantiles of the feature columns are computed in a distributed fashion using
quantile_ops.quantile_accumulator. After a certain number of steps of parallel
accumulation of quantile statistics, we decide on bucket boundaries. These
bucket boundaries are then used for the next N steps to accumulate gradients
and hessians per bucket.

In this implementation, we gather quantile statistics and gradient statistics
concurrently. That means we don't wait until we have enough quantile statistics
for bucketization before we start gathering gradient stats. Instead, during
each step we create quantile stats for the next iteration and use the previous
quantile buckets for gradient stats accumulation.
In make_splits, we do these steps:
1) Get the buckets that were used for creating the gradient stats.
2) Create bucket boundaries for the next N iterations and clear the accumulated
   quantile stats.
3) Get the accumulated gradient stats and clear the accumulator. This step can
   run in parallel to step 2.
4) For each leaf node in the current tree (partition):
   4.1) Get the overall gain computed with gradients and hessians of all
        examples that end up in this partition.
   4.2) Compute tensors of left and right cumulative sums of gradients,
        hessians and gains. The first dimension of these tensors is the bucket
        boundaries.
   4.3) Find the gains for all bucket boundaries:
        split_gains = left_gain + right_gain - overall_gain.
   4.4) Find the bucket boundary that has the best gain (argmax(split_gains)).
   4.5) For the sparse handler, we also consider the gain for when the missing
        examples go to the left child and when they go to the right child, and
        pick the default direction that yields the most gain.
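
For intuition, here is a scalar sketch of the gain used in steps 4.1-4.3 (the
actual computation lives in split_handler_ops and additionally applies L1,
tree complexity regularization and the multiclass strategy):

  gain(G, H) = G^2 / (H + l2)  # G, H: sums of gradients/hessians in a node
  split_gain = gain(G_left, H_left) + gain(G_right, H_right)
               - gain(G_left + G_right, H_left + H_right)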
58""" 59 60from __future__ import absolute_import 61from __future__ import division 62from __future__ import print_function 63 64import re 65 66from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler 67from tensorflow.contrib.boosted_trees.proto import learner_pb2 68from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops 69from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops 70from tensorflow.contrib.boosted_trees.python.ops import quantile_ops 71from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops 72from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops 73from tensorflow.python.framework import constant_op 74from tensorflow.python.framework import dtypes 75from tensorflow.python.framework import function 76from tensorflow.python.framework import ops 77from tensorflow.python.framework import sparse_tensor 78from tensorflow.python.framework import tensor_shape 79from tensorflow.python.ops import array_ops 80from tensorflow.python.ops import control_flow_ops 81from tensorflow.python.ops import math_ops 82 83 84_BIAS_FEATURE_ID = -1 85# Pattern to remove all non alpha numeric from a string. 86_PATTERN = re.compile(r"[\W_]+") 87 88 89class InequalitySplitHandler(base_split_handler.BaseSplitHandler): 90 """Base class for handlers of inequality splits.""" 91 92 def __init__(self, 93 l1_regularization, 94 l2_regularization, 95 tree_complexity_regularization, 96 min_node_weight, 97 feature_column_group_id, 98 epsilon, 99 num_quantiles, 100 gradient_shape, 101 hessian_shape, 102 multiclass_strategy, 103 init_stamp_token=0, 104 loss_uses_sum_reduction=False, 105 name=None): 106 """Initialize the internal state for this split handler. 107 108 Args: 109 l1_regularization: L1 regularization applied for this split handler. 110 l2_regularization: L2 regularization applied for this split handler. 111 tree_complexity_regularization: Tree complexity regularization applied 112 for this split handler. 113 min_node_weight: Minimum sum of weights of examples in each partition to 114 be considered for splitting. 115 feature_column_group_id: Feature column group index. 116 epsilon: A float, the error bound for quantile computation. 117 num_quantiles: An int, the number of buckets to create from the histogram. 118 gradient_shape: A TensorShape, containing shape of gradients. 119 hessian_shape: A TensorShape, containing shape of hessians. 120 multiclass_strategy: Strategy describing how to treat multiclass problems. 121 init_stamp_token: A tensor containing an scalar for initial stamp of the 122 stamped objects. 123 loss_uses_sum_reduction: A scalar boolean tensor that specifies whether 124 SUM or MEAN reduction was used for the loss. 125 name: An optional handler name. 
126 """ 127 super(InequalitySplitHandler, self).__init__( 128 name=name, 129 l1_regularization=l1_regularization, 130 l2_regularization=l2_regularization, 131 tree_complexity_regularization=tree_complexity_regularization, 132 min_node_weight=min_node_weight, 133 feature_column_group_id=feature_column_group_id, 134 gradient_shape=gradient_shape, 135 hessian_shape=hessian_shape, 136 multiclass_strategy=multiclass_strategy, 137 loss_uses_sum_reduction=loss_uses_sum_reduction) 138 self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( 139 init_stamp_token, 140 gradient_shape, 141 hessian_shape, 142 name="StatsAccumulator/{}".format(self._name)) 143 # Allocate both stats accumulator and quantile accumulator on the same 144 # device so that we can build splits with fewer RPCs. 145 with ops.colocate_with(self._stats_accumulator.resource_handle): 146 self._quantile_accumulator = quantile_ops.QuantileAccumulator( 147 init_stamp_token, 148 epsilon=epsilon, 149 num_quantiles=num_quantiles, 150 name="QuantileAccumulator/{}".format(self._name)) 151 152 def reset(self, stamp_token, next_stamp_token): 153 reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token) 154 reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token) 155 return control_flow_ops.group([reset_1, reset_2]) 156 157 158class DenseSplitHandler(InequalitySplitHandler): 159 """Computes stats and finds the best inequality splits on dense columns.""" 160 161 def __init__(self, 162 dense_float_column, 163 l1_regularization, 164 l2_regularization, 165 tree_complexity_regularization, 166 min_node_weight, 167 feature_column_group_id, 168 epsilon, 169 num_quantiles, 170 gradient_shape, 171 hessian_shape, 172 multiclass_strategy, 173 init_stamp_token=0, 174 loss_uses_sum_reduction=False, 175 weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE, 176 name=None): 177 """Initialize the internal state for this split handler. 178 179 Args: 180 dense_float_column: A `Tensor` column associated with this handler. 181 l1_regularization: L1 regularization applied for this split handler. 182 l2_regularization: L2 regularization applied for this split handler. 183 tree_complexity_regularization: Tree complexity regularization applied 184 for this split handler. 185 min_node_weight: Minimum sum of weights of examples in each partition to 186 be considered for splitting. 187 feature_column_group_id: Feature column group index. 188 epsilon: A float, the error bound for quantile computation. 189 num_quantiles: An int, the number of buckets to create from the histogram. 190 gradient_shape: A TensorShape, containing shape of gradients. 191 hessian_shape: A TensorShape, containing shape of hessians. 192 multiclass_strategy: Strategy describing how to treat multiclass problems. 193 init_stamp_token: A tensor containing an scalar for initial stamp of the 194 stamped objects. 195 loss_uses_sum_reduction: A scalar boolean tensor that specifies whether 196 SUM or MEAN reduction was used for the loss. 197 weak_learner_type: Specifies the type of weak learner to use. 198 name: An optional handler name. 
199 """ 200 super(DenseSplitHandler, self).__init__( 201 l1_regularization=l1_regularization, 202 l2_regularization=l2_regularization, 203 tree_complexity_regularization=tree_complexity_regularization, 204 min_node_weight=min_node_weight, 205 feature_column_group_id=feature_column_group_id, 206 epsilon=epsilon, 207 num_quantiles=num_quantiles, 208 init_stamp_token=init_stamp_token, 209 name=name, 210 gradient_shape=gradient_shape, 211 hessian_shape=hessian_shape, 212 multiclass_strategy=multiclass_strategy, 213 loss_uses_sum_reduction=loss_uses_sum_reduction) 214 self._dense_float_column = dense_float_column 215 self._weak_learner_type = weak_learner_type 216 # Register dense_make_stats_update function as an Op to the graph. 217 g = ops.get_default_graph() 218 dense_make_stats_update.add_to_graph(g) 219 220 def scheduled_reads(self): 221 return [self._quantile_accumulator.schedule_get_buckets()] 222 223 def update_stats(self, stamp_token, example_partition_ids, gradients, 224 hessians, empty_gradients, empty_hessians, weights, 225 is_active, scheduled_reads): 226 """Updates the state for dense split handler. 227 228 Args: 229 stamp_token: An int32 scalar tensor containing the current stamp token. 230 example_partition_ids: A dense tensor, containing an int32 for each 231 example which is the partition id that the example ends up in. 232 gradients: A dense tensor of gradients. 233 hessians: A dense tensor of hessians. 234 empty_gradients: A dense empty tensor of the same shape (for dimensions > 235 0) as gradients. 236 empty_hessians: A dense empty tensor of the same shape (for dimensions > 237 0) as hessians. 238 weights: A dense float32 tensor with a weight for each example. 239 is_active: A boolean tensor that says if this handler is active or not. 240 One value for the current layer and one value for the next layer. 241 scheduled_reads: List of scheduled reads for this handler. 242 243 Returns: 244 The op that updates the stats for this handler. 
245 """ 246 name = _PATTERN.sub("", self._name) 247 with ops.name_scope(name, "DenseSplitHandler"): 248 are_buckets_ready, buckets = scheduled_reads[0] 249 (quantile_values, quantile_weights, example_partition_ids, 250 feature_ids, gradients, hessians) = dense_make_stats_update( 251 is_active, are_buckets_ready, self._dense_float_column, buckets, 252 example_partition_ids, gradients, hessians, weights, empty_gradients, 253 empty_hessians) 254 update_quantiles = self._quantile_accumulator.schedule_add_summary( 255 stamp_token=stamp_token, 256 column=quantile_values, 257 example_weights=quantile_weights) 258 update_stats = self._stats_accumulator.schedule_add( 259 example_partition_ids, feature_ids, gradients, hessians) 260 return control_flow_ops.no_op(), [update_quantiles, update_stats] 261 262 def make_splits(self, stamp_token, next_stamp_token, class_id): 263 """Create the best split using the accumulated stats and flush the state.""" 264 if (self._gradient_shape == tensor_shape.scalar() and 265 self._hessian_shape == tensor_shape.scalar()): 266 handler = make_dense_split_scalar 267 else: 268 handler = make_dense_split_tensor 269 270 are_splits_ready, partition_ids, gains, split_infos = ( 271 handler(self._quantile_accumulator.resource_handle, 272 self._stats_accumulator.resource_handle, stamp_token, 273 next_stamp_token, self._multiclass_strategy, class_id, 274 self._feature_column_group_id, self._l1_regularization, 275 self._l2_regularization, self._tree_complexity_regularization, 276 self._min_node_weight, self._loss_uses_sum_reduction, 277 self._weak_learner_type)) 278 return are_splits_ready, partition_ids, gains, split_infos 279 280 281def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, 282 stamp_token, next_stamp_token, multiclass_strategy, 283 class_id, feature_column_id, l1_regularization, 284 l2_regularization, tree_complexity_regularization, 285 min_node_weight, is_multi_dimentional, 286 loss_uses_sum_reduction, weak_learner_type): 287 """Function that builds splits for a dense feature column.""" 288 # Get the bucket boundaries 289 are_splits_ready, buckets = ( 290 gen_quantile_ops.quantile_accumulator_get_buckets( 291 quantile_accumulator_handles=[quantile_accumulator_handle], 292 stamp_token=stamp_token)) 293 # quantile_accumulator_get_buckets returns a list of results per handle that 294 # we pass to it. In this case we're getting results just for one resource. 295 are_splits_ready = are_splits_ready[0] 296 buckets = buckets[0] 297 298 # After we receive the boundaries from previous iteration we can flush 299 # the quantile accumulator. 300 with ops.control_dependencies([buckets]): 301 flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( 302 quantile_accumulator_handle=quantile_accumulator_handle, 303 stamp_token=stamp_token, 304 next_stamp_token=next_stamp_token) 305 306 if is_multi_dimentional: 307 num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( 308 gen_stats_accumulator_ops.stats_accumulator_tensor_flush( 309 stats_accumulator_handle, stamp_token, next_stamp_token)) 310 else: 311 num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( 312 gen_stats_accumulator_ops.stats_accumulator_scalar_flush( 313 stats_accumulator_handle, stamp_token, next_stamp_token)) 314 # For sum_reduction, we don't need to divide by number of minibatches. 
  num_minibatches = control_flow_ops.cond(
      loss_uses_sum_reduction,
      lambda: math_ops.cast(1, dtypes.int64),
      lambda: num_minibatches)
  # Put quantile and stats accumulator flushing in the dependency path.
  with ops.control_dependencies([flush_quantiles, partition_ids]):
    are_splits_ready = array_ops.identity(are_splits_ready)
  partition_ids, gains, split_infos = (
      split_handler_ops.build_dense_inequality_splits(
          num_minibatches=num_minibatches,
          bucket_boundaries=buckets,
          partition_ids=partition_ids,
          bucket_ids=bucket_ids,
          gradients=gradients,
          hessians=hessians,
          class_id=class_id,
          feature_column_group_id=feature_column_id,
          l1_regularization=l1_regularization,
          l2_regularization=l2_regularization,
          tree_complexity_regularization=tree_complexity_regularization,
          min_node_weight=min_node_weight,
          multiclass_strategy=multiclass_strategy,
          weak_learner_type=weak_learner_type))
  return are_splits_ready, partition_ids, gains, split_infos


class SparseSplitHandler(InequalitySplitHandler):
  """Computes stats and finds the best inequality splits on sparse columns."""

  def __init__(self,
               sparse_float_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      sparse_float_column: A `SparseTensor` column associated with this
        handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
        for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
        be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the
        histogram.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass
        problems.
      init_stamp_token: A tensor containing a scalar for the initial stamp of
        the stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
        SUM or MEAN reduction was used for the loss.
      name: An optional handler name.
380 """ 381 super(SparseSplitHandler, self).__init__( 382 l1_regularization=l1_regularization, 383 l2_regularization=l2_regularization, 384 tree_complexity_regularization=tree_complexity_regularization, 385 min_node_weight=min_node_weight, 386 feature_column_group_id=feature_column_group_id, 387 epsilon=epsilon, 388 num_quantiles=num_quantiles, 389 gradient_shape=gradient_shape, 390 hessian_shape=hessian_shape, 391 multiclass_strategy=multiclass_strategy, 392 init_stamp_token=init_stamp_token, 393 loss_uses_sum_reduction=loss_uses_sum_reduction, 394 name=name) 395 self._sparse_float_column = sparse_float_column 396 397 def scheduled_reads(self): 398 return [self._quantile_accumulator.schedule_get_buckets()] 399 400 def update_stats(self, stamp_token, example_partition_ids, gradients, 401 hessians, empty_gradients, empty_hessians, weights, 402 is_active, scheduled_reads): 403 """Updates the state for dense split handler. 404 405 Args: 406 stamp_token: An int32 scalar tensor containing the current stamp token. 407 example_partition_ids: A dense tensor, containing an int32 for each 408 example which is the partition id that the example ends up in. 409 gradients: A dense tensor of gradients. 410 hessians: A dense tensor of hessians. 411 empty_gradients: A dense empty tensor of the same shape (for dimensions > 412 0) as gradients. 413 empty_hessians: A dense empty tensor of the same shape (for dimensions > 414 0) as hessians. 415 weights: A dense float32 tensor with a weight for each example. 416 is_active: A boolean tensor that says if this handler is active or not. 417 One value for the current layer and one value for the next layer. 418 scheduled_reads: List of results from the scheduled reads. 419 420 Returns: 421 The op that updates the stats for this handler. 
422 """ 423 are_buckets_ready, buckets = scheduled_reads[0] 424 with ops.name_scope(self._name, "SparseSplitHandler"): 425 (quantile_indices, quantile_values, quantile_shapes, quantile_weights, 426 example_partition_ids, feature_ids, gradients, 427 hessians) = sparse_make_stats_update( 428 is_active, are_buckets_ready, self._sparse_float_column.indices, 429 self._sparse_float_column.values, 430 self._sparse_float_column.dense_shape, buckets, 431 example_partition_ids, gradients, hessians, weights, empty_gradients, 432 empty_hessians) 433 update_quantiles = self._quantile_accumulator.schedule_add_summary( 434 stamp_token=stamp_token, 435 column=sparse_tensor.SparseTensor(quantile_indices, quantile_values, 436 quantile_shapes), 437 example_weights=quantile_weights) 438 update_stats = self._stats_accumulator.schedule_add( 439 example_partition_ids, feature_ids, gradients, hessians) 440 return (control_flow_ops.no_op(), [update_quantiles, update_stats]) 441 442 def make_splits(self, stamp_token, next_stamp_token, class_id): 443 """Create the best split using the accumulated stats and flush the state.""" 444 if (self._gradient_shape == tensor_shape.scalar() and 445 self._hessian_shape == tensor_shape.scalar()): 446 handler = make_sparse_split_scalar 447 else: 448 handler = make_sparse_split_tensor 449 450 are_splits_ready, partition_ids, gains, split_infos = ( 451 handler(self._quantile_accumulator.resource_handle, 452 self._stats_accumulator.resource_handle, stamp_token, 453 next_stamp_token, self._multiclass_strategy, class_id, 454 self._feature_column_group_id, self._l1_regularization, 455 self._l2_regularization, self._tree_complexity_regularization, 456 self._min_node_weight, self._loss_uses_sum_reduction)) 457 return are_splits_ready, partition_ids, gains, split_infos 458 459 460def _make_sparse_split( 461 quantile_accumulator_handle, stats_accumulator_handle, stamp_token, 462 next_stamp_token, multiclass_strategy, class_id, feature_column_id, 463 l1_regularization, l2_regularization, tree_complexity_regularization, 464 min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): 465 """Function that builds splits for a sparse feature column.""" 466 # Get the bucket boundaries 467 are_splits_ready, buckets = ( 468 gen_quantile_ops.quantile_accumulator_get_buckets( 469 quantile_accumulator_handles=[quantile_accumulator_handle], 470 stamp_token=stamp_token)) 471 # quantile_accumulator_get_buckets returns a list of results per handle that 472 # we pass to it. In this case we're getting results just for one resource. 473 are_splits_ready = are_splits_ready[0] 474 buckets = buckets[0] 475 476 # After we receive the boundaries from previous iteration we can flush 477 # the quantile accumulator. 
  with ops.control_dependencies([buckets]):
    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
        quantile_accumulator_handle=quantile_accumulator_handle,
        stamp_token=stamp_token,
        next_stamp_token=next_stamp_token)

  if is_multi_dimentional:
    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
            stats_accumulator_handle, stamp_token, next_stamp_token))
  else:
    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
            stats_accumulator_handle, stamp_token, next_stamp_token))
  num_minibatches = control_flow_ops.cond(
      loss_uses_sum_reduction,
      lambda: math_ops.cast(1, dtypes.int64),
      lambda: num_minibatches)
  # Put quantile and stats accumulator flushing in the dependency path.
  with ops.control_dependencies([flush_quantiles, partition_ids]):
    are_splits_ready = array_ops.identity(are_splits_ready)
  partition_ids, gains, split_infos = (
      split_handler_ops.build_sparse_inequality_splits(
          num_minibatches=num_minibatches,
          bucket_boundaries=buckets,
          partition_ids=partition_ids,
          bucket_ids=bucket_ids,
          gradients=gradients,
          hessians=hessians,
          class_id=class_id,
          feature_column_group_id=feature_column_id,
          l1_regularization=l1_regularization,
          l2_regularization=l2_regularization,
          tree_complexity_regularization=tree_complexity_regularization,
          min_node_weight=min_node_weight,
          bias_feature_id=_BIAS_FEATURE_ID,
          multiclass_strategy=multiclass_strategy))
  return are_splits_ready, partition_ids, gains, split_infos


def _specialize_make_split_dense(func, is_multi_dimentional):
  """Builds a specialized version of the function."""

  @function.Defun(
      dtypes.resource,
      dtypes.resource,
      dtypes.int64,
      dtypes.int64,
      dtypes.int32,
      dtypes.int32,
      dtypes.int32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.bool,
      dtypes.int32,
      noinline=True)
  def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
        l1_regularization, l2_regularization, tree_complexity_regularization,
        min_node_weight, loss_uses_sum_reduction, weak_learner_type):
    """Function that builds splits for a dense feature column."""
    return func(quantile_accumulator_handle, stats_accumulator_handle,
                stamp_token, next_stamp_token, multiclass_strategy, class_id,
                feature_column_id, l1_regularization, l2_regularization,
                tree_complexity_regularization, min_node_weight,
                is_multi_dimentional, loss_uses_sum_reduction,
                weak_learner_type)

  return f


def _specialize_make_split_sparse(func, is_multi_dimentional):
  """Builds a specialized version of the function."""

  @function.Defun(
      dtypes.resource,
      dtypes.resource,
      dtypes.int64,
      dtypes.int64,
      dtypes.int32,
      dtypes.int32,
      dtypes.int32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.bool,
      noinline=True)
  def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
        l1_regularization, l2_regularization, tree_complexity_regularization,
        min_node_weight, loss_uses_sum_reduction):
    """Function that builds splits for a sparse feature column."""
column.""" 573 return func(quantile_accumulator_handle, stats_accumulator_handle, 574 stamp_token, next_stamp_token, multiclass_strategy, class_id, 575 feature_column_id, l1_regularization, l2_regularization, 576 tree_complexity_regularization, min_node_weight, 577 is_multi_dimentional, loss_uses_sum_reduction) 578 579 return f 580 581 582make_dense_split_scalar = _specialize_make_split_dense( 583 _make_dense_split, is_multi_dimentional=False) 584 585make_dense_split_tensor = _specialize_make_split_dense( 586 _make_dense_split, is_multi_dimentional=True) 587 588make_sparse_split_scalar = _specialize_make_split_sparse( 589 _make_sparse_split, is_multi_dimentional=False) 590make_sparse_split_tensor = _specialize_make_split_sparse( 591 _make_sparse_split, is_multi_dimentional=True) 592 593 594@function.Defun( 595 dtypes.bool, 596 dtypes.bool, 597 dtypes.float32, 598 dtypes.float32, 599 dtypes.int32, 600 dtypes.float32, 601 dtypes.float32, 602 dtypes.float32, 603 dtypes.float32, 604 dtypes.float32, 605 noinline=True) 606def dense_make_stats_update(is_active, are_buckets_ready, float_column, 607 quantile_buckets, example_partition_ids, gradients, 608 hessians, weights, empty_gradients, empty_hessians): 609 """Updates the state for dense split handler.""" 610 empty_float = constant_op.constant_v1([], dtype=dtypes.float32) 611 612 quantile_values, quantile_weights = control_flow_ops.cond( 613 is_active[1], # For the next layer, this handler is inactive. 614 lambda: (float_column, weights), 615 lambda: (empty_float, empty_float)) 616 617 def ready_inputs_fn(): 618 """Branch to execute when quantiles are ready.""" 619 quantized_feature = quantile_ops.quantiles([float_column], [], 620 [quantile_buckets], [], []) 621 quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) 622 quantized_feature = array_ops.squeeze(quantized_feature, axis=0) 623 return (example_partition_ids, quantized_feature, gradients, hessians) 624 625 def not_ready_inputs_fn(): 626 return (constant_op.constant_v1([], dtype=dtypes.int32), 627 constant_op.constant_v1([[]], dtype=dtypes.int64, shape=[1, 2]), 628 empty_gradients, empty_hessians) 629 630 example_partition_ids, feature_ids, gradients, hessians = ( 631 control_flow_ops.cond( 632 math_ops.logical_and( 633 math_ops.logical_and(are_buckets_ready, 634 array_ops.size(quantile_buckets) > 0), 635 is_active[0]), ready_inputs_fn, not_ready_inputs_fn)) 636 return (quantile_values, quantile_weights, example_partition_ids, feature_ids, 637 gradients, hessians) 638 639 640@function.Defun( 641 dtypes.bool, 642 dtypes.bool, 643 dtypes.int64, 644 dtypes.float32, 645 dtypes.int64, 646 dtypes.float32, 647 dtypes.int32, 648 dtypes.float32, 649 dtypes.float32, 650 dtypes.float32, 651 dtypes.float32, 652 dtypes.float32, 653 noinline=True) 654def sparse_make_stats_update( 655 is_active, are_buckets_ready, sparse_column_indices, sparse_column_values, 656 sparse_column_shape, quantile_buckets, example_partition_ids, gradients, 657 hessians, weights, empty_gradients, empty_hessians): 658 """Updates the state for this split handler.""" 659 660 def quantiles_ready(): 661 """The subgraph for when the quantiles are ready.""" 662 quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [], 663 [quantile_buckets], 664 [sparse_column_indices]) 665 666 quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) 667 quantized_feature = array_ops.squeeze(quantized_feature, axis=0) 668 669 example_indices, _ = array_ops.split( 670 sparse_column_indices, 
    example_indices = array_ops.squeeze(example_indices, [1])
    filtered_gradients = array_ops.gather(gradients, example_indices)
    filtered_hessians = array_ops.gather(hessians, example_indices)
    filtered_partition_ids = array_ops.gather(example_partition_ids,
                                              example_indices)
    unique_partitions, mapped_partitions = array_ops.unique(
        example_partition_ids)

    # Compute aggregate stats for each partition.
    # Since unsorted_segment_sum can be numerically unstable, use 64bit
    # operation.
    gradients64 = math_ops.cast(gradients, dtypes.float64)
    hessians64 = math_ops.cast(hessians, dtypes.float64)
    per_partition_gradients = math_ops.unsorted_segment_sum(
        gradients64, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_hessians = math_ops.unsorted_segment_sum(
        hessians64, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_gradients = math_ops.cast(per_partition_gradients,
                                            dtypes.float32)
    per_partition_hessians = math_ops.cast(per_partition_hessians,
                                           dtypes.float32)
    # Prepend a bias feature per partition that accumulates the stats for all
    # examples in that partition.
    bias_feature_ids = array_ops.fill(
        array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
    bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
    zeros = array_ops.zeros_like(bias_feature_ids)
    bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1)

    partition_ids = array_ops.concat(
        [unique_partitions, filtered_partition_ids], 0)
    filtered_gradients = array_ops.concat(
        [per_partition_gradients, filtered_gradients], 0)
    filtered_hessians = array_ops.concat(
        [per_partition_hessians, filtered_hessians], 0)

    bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)

    return partition_ids, bucket_ids, filtered_gradients, filtered_hessians

  def quantiles_not_ready():
    """The subgraph for when the quantiles are not ready."""
    return (constant_op.constant_v1([], dtype=dtypes.int32),
            constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]),
            empty_gradients, empty_hessians)

  empty_float = constant_op.constant_v1([], dtype=dtypes.float32)
  handler_not_active = (constant_op.constant([], dtype=dtypes.int64,
                                             shape=[0, 2]),
                        empty_float,
                        constant_op.constant([0, 1], dtype=dtypes.int64),
                        empty_float)
  handler_active = (sparse_column_indices, sparse_column_values,
                    sparse_column_shape, weights)
  quantile_indices, quantile_values, quantile_shape, quantile_weights = (
      control_flow_ops.cond(is_active[1], lambda: handler_active,
                            lambda: handler_not_active))

  example_partition_ids, feature_ids, gradients, hessians = (
      control_flow_ops.cond(
          math_ops.logical_and(are_buckets_ready,
                               array_ops.size(quantile_buckets) > 0),
          quantiles_ready, quantiles_not_ready))

  return (quantile_indices, quantile_values, quantile_shape, quantile_weights,
          example_partition_ids, feature_ids, gradients, hessians)