# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Implementation of handler for split nodes for float columns.
16
17The general idea in batch split finding is that each handler will accumulate its
18own statistics on multiple workers. After some steps, the master runs
19make_splits() sub-graph of each handler and each handler returns its best split
20per partition.
21
22The way we ensure consistency of statistics is by using stamp_tokens for read
23and write operations. During each update of the model, a new stamp token is
24created. This stamp token makes sure that updates from the previous iterations
25are not included in the statistics for this iteration.
26
27Inequality splits for float features are created similar to the method described
28in Approximate Algorithm described in https://arxiv.org/pdf/1603.02754v3.pdf.
29Weighted quantiles of the feature columns are computed in a distributed fashion
30using quantile_ops.quantile_accumulator.
31After certain number of steps of parallel accumulation of quantile statistics,
32we decide on bucket boundaries. These bucket boundaries are then used for the
33next N steps to accumulate gradients and hessians per bucket.
34
35In this implementation, we gather quantile statistics and gradient statistics
36concurrently. That means that we don't wait until we have enough quantile
37statistics for bucketization before we start gathering gradient stats. Instead
38during each step we create quantile stats for the next iteration and use the
39previous quantile buckets for gradient stats accumulation.
40In make_splits, we do these steps:
411) Get the buckets that were used creating for the gradient stats.
422) Create bucket boundaries for the next N iterations and clear the accumulated
43   quantile stats.
44n3) Get the accumulated gradient stats and clear the accumulator. This step can
45   run in parallel to step 2.
464) For each leaf node in the current tree (partition):
47   4.1) Get the overall gain computed with gradients and hessians of all
48        examples that end up in this partition.
49   4.2) Compute tensors of left and right cumulative sum of gradients, hessians
50        and gain. The first dimension of these tensors are the bucket
51        boundaries.
52   4.3) Find the gains for all bucket boundaries:
53        split_gains = left_gain + right_gain - overall_gain.
54   4.4) Find the bucket boundary that has the best gain (argmax(split_gains))
55   4.5) For Sparse handler, we also consider the gain for when the examples go
56        the left child and when the examples go to the right child and pick the
57        default direction that yields the most gain.
58"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


_BIAS_FEATURE_ID = -1
# Pattern to remove all non-alphanumeric characters from a string.
_PATTERN = re.compile(r"[\W_]+")
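# For example, _PATTERN.sub("", "dense/feature:0") yields "densefeature0",
# which makes a handler name safe to use as an op name scope below.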


class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
  """Base class for handlers of inequality splits."""

  def __init__(self,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A tensor containing a scalar for the initial stamp of
          the stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
          SUM or MEAN reduction was used for the loss.
      name: An optional handler name.
    """
    super(InequalitySplitHandler, self).__init__(
        name=name,
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy,
        loss_uses_sum_reduction=loss_uses_sum_reduction)
    self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
        init_stamp_token,
        gradient_shape,
        hessian_shape,
        name="StatsAccumulator/{}".format(self._name))
    # Allocate both the stats accumulator and the quantile accumulator on the
    # same device so that we can build splits with fewer RPCs.
    with ops.colocate_with(self._stats_accumulator.resource_handle):
      self._quantile_accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token,
          epsilon=epsilon,
          num_quantiles=num_quantiles,
          name="QuantileAccumulator/{}".format(self._name))

  def reset(self, stamp_token, next_stamp_token):
    reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token)
    reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token)
    return control_flow_ops.group([reset_1, reset_2])
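

# A minimal pure-Python sketch of the stamp-token protocol described in the
# module docstring. Illustrative only; it is not used by the handlers in this
# file, and the real accumulators are stamped C++ resources. The key property
# is that operations carrying a stale stamp are ignored, which keeps
# statistics from previous iterations out of the current one.
class _ToyStampedAccumulator(object):
  """Illustrative sketch of a stamped accumulator; not part of the API."""

  def __init__(self, init_stamp_token=0):
    self._stamp = init_stamp_token
    self._total = 0.0

  def add(self, stamp_token, value):
    # Updates stamped with anything but the current token are dropped.
    if stamp_token == self._stamp:
      self._total += value

  def flush(self, stamp_token, next_stamp_token):
    # Read the accumulated value, then re-stamp and clear the state so that
    # late writes from the old iteration can no longer land.
    result = self._total if stamp_token == self._stamp else 0.0
    self._stamp = next_stamp_token
    self._total = 0.0
    return result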


class DenseSplitHandler(InequalitySplitHandler):
  """Computes stats and finds the best inequality splits on dense columns."""

  def __init__(self,
               dense_float_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      dense_float_column: A `Tensor` column associated with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A tensor containing a scalar for the initial stamp of
          the stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
          SUM or MEAN reduction was used for the loss.
      weak_learner_type: Specifies the type of weak learner to use.
      name: An optional handler name.
    """
    super(DenseSplitHandler, self).__init__(
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        epsilon=epsilon,
        num_quantiles=num_quantiles,
        init_stamp_token=init_stamp_token,
        name=name,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy,
        loss_uses_sum_reduction=loss_uses_sum_reduction)
    self._dense_float_column = dense_float_column
    self._weak_learner_type = weak_learner_type
    # Register the dense_make_stats_update function as an op in the graph.
    g = ops.get_default_graph()
    dense_make_stats_update.add_to_graph(g)

  def scheduled_reads(self):
    return [self._quantile_accumulator.schedule_get_buckets()]

  def update_stats(self, stamp_token, example_partition_ids, gradients,
                   hessians, empty_gradients, empty_hessians, weights,
                   is_active, scheduled_reads):
    """Updates the state for dense split handler.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
        example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
        0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
        0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
          One value for the current layer and one value for the next layer.
      scheduled_reads: List of scheduled reads for this handler.

    Returns:
      The op that updates the stats for this handler.
    """
    name = _PATTERN.sub("", self._name)
    with ops.name_scope(name, "DenseSplitHandler"):
      are_buckets_ready, buckets = scheduled_reads[0]
      (quantile_values, quantile_weights, example_partition_ids,
       feature_ids, gradients, hessians) = dense_make_stats_update(
           is_active, are_buckets_ready, self._dense_float_column, buckets,
           example_partition_ids, gradients, hessians, weights, empty_gradients,
           empty_hessians)
      update_quantiles = self._quantile_accumulator.schedule_add_summary(
          stamp_token=stamp_token,
          column=quantile_values,
          example_weights=quantile_weights)
      update_stats = self._stats_accumulator.schedule_add(
          example_partition_ids, feature_ids, gradients, hessians)
      return control_flow_ops.no_op(), [update_quantiles, update_stats]

  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Create the best split using the accumulated stats and flush the state."""
    if (self._gradient_shape == tensor_shape.scalar() and
        self._hessian_shape == tensor_shape.scalar()):
      handler = make_dense_split_scalar
    else:
      handler = make_dense_split_tensor

    are_splits_ready, partition_ids, gains, split_infos = (
        handler(self._quantile_accumulator.resource_handle,
                self._stats_accumulator.resource_handle, stamp_token,
                next_stamp_token, self._multiclass_strategy, class_id,
                self._feature_column_group_id, self._l1_regularization,
                self._l2_regularization, self._tree_complexity_regularization,
                self._min_node_weight, self._loss_uses_sum_reduction,
                self._weak_learner_type))
    return are_splits_ready, partition_ids, gains, split_infos


def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
                      stamp_token, next_stamp_token, multiclass_strategy,
                      class_id, feature_column_id, l1_regularization,
                      l2_regularization, tree_complexity_regularization,
                      min_node_weight, is_multi_dimentional,
                      loss_uses_sum_reduction, weak_learner_type):
  """Function that builds splits for a dense feature column."""
  # Get the bucket boundaries.
  are_splits_ready, buckets = (
      gen_quantile_ops.quantile_accumulator_get_buckets(
          quantile_accumulator_handles=[quantile_accumulator_handle],
          stamp_token=stamp_token))
  # quantile_accumulator_get_buckets returns a list of results, one per handle
  # that we pass to it. In this case we are getting results just for one
  # resource.
  are_splits_ready = are_splits_ready[0]
  buckets = buckets[0]

  # After we receive the boundaries from the previous iteration, we can flush
  # the quantile accumulator.
  with ops.control_dependencies([buckets]):
    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
        quantile_accumulator_handle=quantile_accumulator_handle,
        stamp_token=stamp_token,
        next_stamp_token=next_stamp_token)

  if is_multi_dimentional:
    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
            stats_accumulator_handle, stamp_token, next_stamp_token))
  else:
    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
            stats_accumulator_handle, stamp_token, next_stamp_token))
  # For SUM reduction, we don't need to divide by the number of minibatches.
  num_minibatches = control_flow_ops.cond(
      loss_uses_sum_reduction,
      lambda: math_ops.cast(1, dtypes.int64),
      lambda: num_minibatches)
  # Put quantile and stats accumulator flushing in the dependency path.
  with ops.control_dependencies([flush_quantiles, partition_ids]):
    are_splits_ready = array_ops.identity(are_splits_ready)
  partition_ids, gains, split_infos = (
      split_handler_ops.build_dense_inequality_splits(
          num_minibatches=num_minibatches,
          bucket_boundaries=buckets,
          partition_ids=partition_ids,
          bucket_ids=bucket_ids,
          gradients=gradients,
          hessians=hessians,
          class_id=class_id,
          feature_column_group_id=feature_column_id,
          l1_regularization=l1_regularization,
          l2_regularization=l2_regularization,
          tree_complexity_regularization=tree_complexity_regularization,
          min_node_weight=min_node_weight,
          multiclass_strategy=multiclass_strategy,
          weak_learner_type=weak_learner_type))
  return are_splits_ready, partition_ids, gains, split_infos
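

# A pure-Python sketch of the per-partition gain search that
# build_dense_inequality_splits performs internally (steps 4.1-4.4 of the
# module docstring), assuming scalar gradients/hessians, only L2
# regularization, and gain = G^2 / (H + l2). Illustrative only; the real op
# also handles L1, tree complexity, min_node_weight and multiclass logic.
def _toy_best_dense_split(bucket_gradients, bucket_hessians,
                          l2_regularization):
  """Returns (bucket_index, split_gain) of the best inequality split."""

  def gain(g, h):
    return g * g / (h + l2_regularization)

  total_g = sum(bucket_gradients)
  total_h = sum(bucket_hessians)
  overall_gain = gain(total_g, total_h)  # Step 4.1.
  best_bucket, best_gain = None, float("-inf")
  left_g, left_h = 0.0, 0.0
  for i in range(len(bucket_gradients) - 1):
    # Step 4.2: left/right cumulative sums per bucket boundary.
    left_g += bucket_gradients[i]
    left_h += bucket_hessians[i]
    right_g, right_h = total_g - left_g, total_h - left_h
    # Step 4.3: split_gains = left_gain + right_gain - overall_gain.
    split_gain = gain(left_g, left_h) + gain(right_g, right_h) - overall_gain
    # Step 4.4: argmax over bucket boundaries.
    if split_gain > best_gain:
      best_bucket, best_gain = i, split_gain
  return best_bucket, best_gain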


class SparseSplitHandler(InequalitySplitHandler):
  """Computes stats and finds the best inequality splits on sparse columns."""

  def __init__(self,
               sparse_float_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      sparse_float_column: A `SparseTensor` column associated with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A tensor containing a scalar for the initial stamp of
          the stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
          SUM or MEAN reduction was used for the loss.
      name: An optional handler name.
    """
    super(SparseSplitHandler, self).__init__(
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        epsilon=epsilon,
        num_quantiles=num_quantiles,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy,
        init_stamp_token=init_stamp_token,
        loss_uses_sum_reduction=loss_uses_sum_reduction,
        name=name)
    self._sparse_float_column = sparse_float_column

  def scheduled_reads(self):
    return [self._quantile_accumulator.schedule_get_buckets()]

  def update_stats(self, stamp_token, example_partition_ids, gradients,
                   hessians, empty_gradients, empty_hessians, weights,
                   is_active, scheduled_reads):
    """Updates the state for the sparse split handler.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
        example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
        0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
        0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
          One value for the current layer and one value for the next layer.
      scheduled_reads: List of results from the scheduled reads.

    Returns:
      The op that updates the stats for this handler.
    """
    are_buckets_ready, buckets = scheduled_reads[0]
    with ops.name_scope(self._name, "SparseSplitHandler"):
      (quantile_indices, quantile_values, quantile_shapes, quantile_weights,
       example_partition_ids, feature_ids, gradients,
       hessians) = sparse_make_stats_update(
           is_active, are_buckets_ready, self._sparse_float_column.indices,
           self._sparse_float_column.values,
           self._sparse_float_column.dense_shape, buckets,
           example_partition_ids, gradients, hessians, weights, empty_gradients,
           empty_hessians)
      update_quantiles = self._quantile_accumulator.schedule_add_summary(
          stamp_token=stamp_token,
          column=sparse_tensor.SparseTensor(quantile_indices, quantile_values,
                                            quantile_shapes),
          example_weights=quantile_weights)
      update_stats = self._stats_accumulator.schedule_add(
          example_partition_ids, feature_ids, gradients, hessians)
      return (control_flow_ops.no_op(), [update_quantiles, update_stats])

  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Create the best split using the accumulated stats and flush the state."""
    if (self._gradient_shape == tensor_shape.scalar() and
        self._hessian_shape == tensor_shape.scalar()):
      handler = make_sparse_split_scalar
    else:
      handler = make_sparse_split_tensor

    are_splits_ready, partition_ids, gains, split_infos = (
        handler(self._quantile_accumulator.resource_handle,
                self._stats_accumulator.resource_handle, stamp_token,
                next_stamp_token, self._multiclass_strategy, class_id,
                self._feature_column_group_id, self._l1_regularization,
                self._l2_regularization, self._tree_complexity_regularization,
                self._min_node_weight, self._loss_uses_sum_reduction))
    return are_splits_ready, partition_ids, gains, split_infos


def _make_sparse_split(
    quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
    next_stamp_token, multiclass_strategy, class_id, feature_column_id,
    l1_regularization, l2_regularization, tree_complexity_regularization,
    min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
  """Function that builds splits for a sparse feature column."""
  # Get the bucket boundaries.
  are_splits_ready, buckets = (
      gen_quantile_ops.quantile_accumulator_get_buckets(
          quantile_accumulator_handles=[quantile_accumulator_handle],
          stamp_token=stamp_token))
  # quantile_accumulator_get_buckets returns a list of results, one per handle
  # that we pass to it. In this case we are getting results just for one
  # resource.
  are_splits_ready = are_splits_ready[0]
  buckets = buckets[0]

  # After we receive the boundaries from the previous iteration, we can flush
  # the quantile accumulator.
  with ops.control_dependencies([buckets]):
    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
        quantile_accumulator_handle=quantile_accumulator_handle,
        stamp_token=stamp_token,
        next_stamp_token=next_stamp_token)

  if is_multi_dimentional:
    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
            stats_accumulator_handle, stamp_token, next_stamp_token))
  else:
    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
            stats_accumulator_handle, stamp_token, next_stamp_token))
  # For SUM reduction, we don't need to divide by the number of minibatches.
  num_minibatches = control_flow_ops.cond(
      loss_uses_sum_reduction,
      lambda: math_ops.cast(1, dtypes.int64),
      lambda: num_minibatches)
  # Put quantile and stats accumulator flushing in the dependency path.
  with ops.control_dependencies([flush_quantiles, partition_ids]):
    are_splits_ready = array_ops.identity(are_splits_ready)
  partition_ids, gains, split_infos = (
      split_handler_ops.build_sparse_inequality_splits(
          num_minibatches=num_minibatches,
          bucket_boundaries=buckets,
          partition_ids=partition_ids,
          bucket_ids=bucket_ids,
          gradients=gradients,
          hessians=hessians,
          class_id=class_id,
          feature_column_group_id=feature_column_id,
          l1_regularization=l1_regularization,
          l2_regularization=l2_regularization,
          tree_complexity_regularization=tree_complexity_regularization,
          min_node_weight=min_node_weight,
          bias_feature_id=_BIAS_FEATURE_ID,
          multiclass_strategy=multiclass_strategy))
  return are_splits_ready, partition_ids, gains, split_infos


def _specialize_make_split_dense(func, is_multi_dimentional):
  """Builds a specialized version of the function."""

  @function.Defun(
      dtypes.resource,
      dtypes.resource,
      dtypes.int64,
      dtypes.int64,
      dtypes.int32,
      dtypes.int32,
      dtypes.int32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.bool,
      dtypes.int32,
      noinline=True)
  def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
        l1_regularization, l2_regularization, tree_complexity_regularization,
        min_node_weight, loss_uses_sum_reduction, weak_learner_type):
    """Function that builds splits for a dense feature column."""
    return func(quantile_accumulator_handle, stats_accumulator_handle,
                stamp_token, next_stamp_token, multiclass_strategy, class_id,
                feature_column_id, l1_regularization, l2_regularization,
                tree_complexity_regularization, min_node_weight,
                is_multi_dimentional, loss_uses_sum_reduction,
                weak_learner_type)

  return f


def _specialize_make_split_sparse(func, is_multi_dimentional):
  """Builds a specialized version of the function."""

  @function.Defun(
      dtypes.resource,
      dtypes.resource,
      dtypes.int64,
      dtypes.int64,
      dtypes.int32,
      dtypes.int32,
      dtypes.int32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.float32,
      dtypes.bool,
      noinline=True)
  def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
        l1_regularization, l2_regularization, tree_complexity_regularization,
        min_node_weight, loss_uses_sum_reduction):
    """Function that builds splits for a sparse feature column."""
    return func(quantile_accumulator_handle, stats_accumulator_handle,
                stamp_token, next_stamp_token, multiclass_strategy, class_id,
                feature_column_id, l1_regularization, l2_regularization,
                tree_complexity_regularization, min_node_weight,
                is_multi_dimentional, loss_uses_sum_reduction)

  return f


make_dense_split_scalar = _specialize_make_split_dense(
    _make_dense_split, is_multi_dimentional=False)

make_dense_split_tensor = _specialize_make_split_dense(
    _make_dense_split, is_multi_dimentional=True)

make_sparse_split_scalar = _specialize_make_split_sparse(
    _make_sparse_split, is_multi_dimentional=False)
make_sparse_split_tensor = _specialize_make_split_sparse(
    _make_sparse_split, is_multi_dimentional=True)


@function.Defun(
    dtypes.bool,
    dtypes.bool,
    dtypes.float32,
    dtypes.float32,
    dtypes.int32,
    dtypes.float32,
    dtypes.float32,
    dtypes.float32,
    dtypes.float32,
    dtypes.float32,
    noinline=True)
def dense_make_stats_update(is_active, are_buckets_ready, float_column,
                            quantile_buckets, example_partition_ids, gradients,
                            hessians, weights, empty_gradients, empty_hessians):
  """Updates the state for dense split handler."""
  empty_float = constant_op.constant_v1([], dtype=dtypes.float32)

  quantile_values, quantile_weights = control_flow_ops.cond(
      is_active[1],  # For the next layer, this handler is inactive.
      lambda: (float_column, weights),
      lambda: (empty_float, empty_float))

  def ready_inputs_fn():
    """Branch to execute when quantiles are ready."""
    quantized_feature = quantile_ops.quantiles([float_column], [],
                                               [quantile_buckets], [], [])
    quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64)
    quantized_feature = array_ops.squeeze(quantized_feature, axis=0)
    return (example_partition_ids, quantized_feature, gradients, hessians)

  def not_ready_inputs_fn():
    return (constant_op.constant_v1([], dtype=dtypes.int32),
            constant_op.constant_v1([[]], dtype=dtypes.int64, shape=[1, 2]),
            empty_gradients, empty_hessians)

  example_partition_ids, feature_ids, gradients, hessians = (
      control_flow_ops.cond(
          math_ops.logical_and(
              math_ops.logical_and(are_buckets_ready,
                                   array_ops.size(quantile_buckets) > 0),
              is_active[0]), ready_inputs_fn, not_ready_inputs_fn))
  return (quantile_values, quantile_weights, example_partition_ids, feature_ids,
          gradients, hessians)
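

# A sketch of the bucketization that quantile_ops.quantiles performs above,
# assuming searchsorted-style semantics (illustrative only; the exact boundary
# handling of the real op, and its sparse/batched variants, may differ).
def _toy_bucketize(values, bucket_boundaries):
  """Maps each float value to the index of its quantile bucket."""
  import bisect
  return [bisect.bisect_left(bucket_boundaries, value) for value in values]

# For example, with boundaries [0.5, 2.5], the values [0.1, 1.7, 3.0] would
# map to bucket ids [0, 1, 2].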


@function.Defun(
    dtypes.bool,
    dtypes.bool,
    dtypes.int64,
    dtypes.float32,
    dtypes.int64,
    dtypes.float32,
    dtypes.int32,
    dtypes.float32,
    dtypes.float32,
    dtypes.float32,
    dtypes.float32,
    dtypes.float32,
    noinline=True)
def sparse_make_stats_update(
    is_active, are_buckets_ready, sparse_column_indices, sparse_column_values,
    sparse_column_shape, quantile_buckets, example_partition_ids, gradients,
    hessians, weights, empty_gradients, empty_hessians):
  """Updates the state for this split handler."""

  def quantiles_ready():
    """The subgraph for when the quantiles are ready."""
    quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [],
                                               [quantile_buckets],
                                               [sparse_column_indices])

    quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64)
    quantized_feature = array_ops.squeeze(quantized_feature, axis=0)

    example_indices, _ = array_ops.split(
        sparse_column_indices, num_or_size_splits=2, axis=1)
    example_indices = array_ops.squeeze(example_indices, [1])
    filtered_gradients = array_ops.gather(gradients, example_indices)
    filtered_hessians = array_ops.gather(hessians, example_indices)
    filtered_partition_ids = array_ops.gather(example_partition_ids,
                                              example_indices)
    unique_partitions, mapped_partitions = array_ops.unique(
        example_partition_ids)

    # Compute aggregate stats for each partition.
    # Since unsorted_segment_sum can be numerically unstable, use 64-bit
    # operations.
    gradients64 = math_ops.cast(gradients, dtypes.float64)
    hessians64 = math_ops.cast(hessians, dtypes.float64)
    per_partition_gradients = math_ops.unsorted_segment_sum(
        gradients64, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_hessians = math_ops.unsorted_segment_sum(
        hessians64, mapped_partitions, array_ops.size(unique_partitions))
    per_partition_gradients = math_ops.cast(per_partition_gradients,
                                            dtypes.float32)
    per_partition_hessians = math_ops.cast(per_partition_hessians,
                                           dtypes.float32)
    # Prepend a bias feature per partition that accumulates the stats for all
    # examples in that partition.
    bias_feature_ids = array_ops.fill(
        array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
    bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
    zeros = array_ops.zeros_like(bias_feature_ids)
    bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1)

    partition_ids = array_ops.concat(
        [unique_partitions, filtered_partition_ids], 0)
    filtered_gradients = array_ops.concat(
        [per_partition_gradients, filtered_gradients], 0)
    filtered_hessians = array_ops.concat(
        [per_partition_hessians, filtered_hessians], 0)

    bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)

    return partition_ids, bucket_ids, filtered_gradients, filtered_hessians

  def quantiles_not_ready():
    """The subgraph for when the quantiles are not ready."""
    return (constant_op.constant_v1([], dtype=dtypes.int32),
            constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]),
            empty_gradients, empty_hessians)

  empty_float = constant_op.constant_v1([], dtype=dtypes.float32)
  handler_not_active = (constant_op.constant(
      [], dtype=dtypes.int64, shape=[0, 2]), empty_float,
                        constant_op.constant([0, 1], dtype=dtypes.int64),
                        empty_float)
  handler_active = (sparse_column_indices, sparse_column_values,
                    sparse_column_shape, weights)
  quantile_indices, quantile_values, quantile_shape, quantile_weights = (
      control_flow_ops.cond(is_active[1], lambda: handler_active,
                            lambda: handler_not_active))

  example_partition_ids, feature_ids, gradients, hessians = (
      control_flow_ops.cond(
          math_ops.logical_and(are_buckets_ready,
                               array_ops.size(quantile_buckets) > 0),
          quantiles_ready, quantiles_not_ready))

  return (quantile_indices, quantile_values, quantile_shape, quantile_weights,
          example_partition_ids, feature_ids, gradients, hessians)