# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of handler for split nodes for categorical columns."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

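# Sentinel feature id under which the per-partition bias stats are
# accumulated. int64.min is assumed never to collide with a real
# (non-negative) categorical feature id.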
_BIAS_FEATURE_ID = int(dtypes.int64.min)


class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
  """Creates equality split type for categorical features."""
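
  # Illustrative usage (not part of the original file; the column and the
  # parameter values below are hypothetical, and tensor_shape refers to
  # tensorflow.python.framework.tensor_shape): the GBDT training loop builds
  # one handler per categorical column and drives it through
  # update_stats()/make_splits(), roughly as:
  #
  #   handler = EqualitySplitHandler(
  #       sparse_int_column=day_of_week_column,  # SparseTensor of int64 ids
  #       l1_regularization=0.0,
  #       l2_regularization=1.0,
  #       tree_complexity_regularization=0.0,
  #       min_node_weight=0.0,
  #       feature_column_group_id=0,
  #       gradient_shape=tensor_shape.scalar(),
  #       hessian_shape=tensor_shape.scalar(),
  #       multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)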

  def __init__(self,
               sparse_int_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      sparse_int_column: A `SparseTensor` column with int64 values associated
        with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
        for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
        be considered for splitting.
      feature_column_group_id: Feature column group index.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass
        problems.
      init_stamp_token: A tensor containing a scalar for the initial stamp of
        the stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
        SUM or MEAN reduction was used for the loss.
      weak_learner_type: Specifies the type of weak learner to use.
      name: An optional handler name.
    """
    super(EqualitySplitHandler, self).__init__(
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy,
        loss_uses_sum_reduction=loss_uses_sum_reduction,
        name=name)
    self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
        init_stamp_token,
        gradient_shape,
        hessian_shape,
        name="StatsAccumulator/{}".format(self._name))
    self._sparse_int_column = sparse_int_column
    self._weak_learner_type = weak_learner_type
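    # The accumulator aggregates (gradient, hessian) sums keyed by
    # (partition id, feature id) pairs across minibatches until flushed.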

  def update_stats(self, stamp_token, example_partition_ids, gradients,
                   hessians, empty_gradients, empty_hessians, weights,
                   is_active, scheduled_reads):
    """Updates the state for equality split handler.

    Args:
      stamp_token: An int32 scalar tensor containing the current stamp token.
      example_partition_ids: A dense tensor, containing an int32 for each
        example which is the partition id that the example ends up in.
      gradients: A dense tensor of gradients.
      hessians: A dense tensor of hessians.
      empty_gradients: A dense empty tensor of the same shape (for dimensions >
        0) as gradients.
      empty_hessians: A dense empty tensor of the same shape (for dimensions >
        0) as hessians.
      weights: A dense float32 tensor with a weight for each example.
      is_active: A boolean tensor that says if this handler is active or not.
        One value for the current layer and one value for the next layer.
      scheduled_reads: List of results from the scheduled reads.

    Returns:
      The op that updates the stats for this handler.

    Raises:
      ValueError: If sparse_int_column is not a single sparse column.
    """
    del scheduled_reads  # Unused by the categorical split handler.

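    # Both branches of the cond below must return matching structures, so the
    # inactive branch emits empty partition ids and placeholder stats.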
    def not_active_inputs():
      return (constant_op.constant([], dtype=dtypes.int32),
              constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]),
              empty_gradients, empty_hessians)

    def active_inputs():
      """The normal flow when the handler is active."""
      # Remove the second column of the example-indices matrix, since only the
      # example (row) index is needed.
      example_indices, _ = array_ops.split(
          self._sparse_int_column.indices, num_or_size_splits=2, axis=1)
      example_indices = array_ops.squeeze(example_indices, [1])

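      # A multivalent example appears once per value in the sparse column, so
      # the gathers below replicate its gradient, hessian and partition id
      # once per (example, feature value) pair.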
      filtered_gradients = array_ops.gather(gradients, example_indices)
      filtered_hessians = array_ops.gather(hessians, example_indices)
      filtered_partition_ids = array_ops.gather(example_partition_ids,
                                                example_indices)
      unique_partitions, mapped_partitions = array_ops.unique(
          example_partition_ids)
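      # mapped_partitions gives, for every example, the index of its partition
      # id within unique_partitions; it serves as the segment ids below.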

      # Compute aggregate stats for each partition.
      # The bias is computed on gradients and hessians (and not
      # filtered_gradients) which have exactly one value per example, so we
      # don't double count a gradient in multivalent columns.
      # Since unsorted_segment_sum can be numerically unstable, accumulate in
      # 64-bit and cast the result back to float32.
      gradients64 = math_ops.cast(gradients, dtypes.float64)
      hessians64 = math_ops.cast(hessians, dtypes.float64)
      per_partition_gradients = math_ops.unsorted_segment_sum(
          gradients64, mapped_partitions, array_ops.size(unique_partitions))
      per_partition_hessians = math_ops.unsorted_segment_sum(
          hessians64, mapped_partitions, array_ops.size(unique_partitions))
      per_partition_gradients = math_ops.cast(per_partition_gradients,
                                              dtypes.float32)
      per_partition_hessians = math_ops.cast(per_partition_hessians,
                                             dtypes.float32)
      # Prepend a bias feature per partition that accumulates the stats for all
      # examples in that partition.
      # The bias is added to the stats even if there are no examples with
      # values in the current sparse column, because other example batches
      # might have values in these partitions, so the bias must stay updated.
      bias_feature_ids = array_ops.fill(
          array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
      bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
      partition_ids = array_ops.concat(
          [unique_partitions, filtered_partition_ids], 0)
      filtered_gradients = array_ops.concat(
          [per_partition_gradients, filtered_gradients], 0)
      filtered_hessians = array_ops.concat(
          [per_partition_hessians, filtered_hessians], 0)
      feature_ids = array_ops.concat(
          [bias_feature_ids, self._sparse_int_column.values], 0)
      # Dimension is always zero for sparse int features.
      dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64)
      feature_ids_and_dimensions = array_ops.stack(
          [feature_ids, dimension_ids], axis=1)
      return (partition_ids, feature_ids_and_dimensions, filtered_gradients,
              filtered_hessians)

    partition_ids, feature_ids, gradients_out, hessians_out = (
        control_flow_ops.cond(is_active[0], active_inputs, not_active_inputs))
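    # Illustrative example: with partitions {0, 1} and feature ids {3, 5},
    # the accumulator ends up holding summed (gradient, hessian) stats keyed
    # by partition id and (feature id, dimension), e.g.
    # (0, (_BIAS_FEATURE_ID, 0)) or (1, (5, 0)).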
    result = self._stats_accumulator.schedule_add(partition_ids, feature_ids,
                                                  gradients_out, hessians_out)
    return (control_flow_ops.no_op(), [result])

  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Create the best split using the accumulated stats and flush the state."""
    # Get the aggregated gradients and hessians per <partition_id, feature_id>
    # pair.
    num_minibatches, partition_ids, feature_ids, gradients, hessians = (
        self._stats_accumulator.flush(stamp_token, next_stamp_token))
    # For SUM loss reduction, the stats should not be divided by the number of
    # minibatches, so force num_minibatches to one.
    num_minibatches = control_flow_ops.cond(
        ops.convert_to_tensor(self._loss_uses_sum_reduction),
        lambda: math_ops.cast(1, dtypes.int64),
        lambda: num_minibatches)
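    # E.g. (illustrative): stats summed over 4 minibatches under MEAN loss
    # reduction get normalized by 4 inside the split op below, while under SUM
    # reduction num_minibatches is 1 and no averaging happens.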
    partition_ids, gains, split_infos = (
        split_handler_ops.build_categorical_equality_splits(
            num_minibatches=num_minibatches,
            partition_ids=partition_ids,
            feature_ids=feature_ids,
            gradients=gradients,
            hessians=hessians,
            class_id=class_id,
            feature_column_group_id=self._feature_column_group_id,
            l1_regularization=self._l1_regularization,
            l2_regularization=self._l2_regularization,
            tree_complexity_regularization=self._tree_complexity_regularization,
            min_node_weight=self._min_node_weight,
            bias_feature_id=_BIAS_FEATURE_ID,
            multiclass_strategy=self._multiclass_strategy,
            weak_learner_type=self._weak_learner_type))
    # There are no warm-up rounds needed in the equality column handler, so we
    # always return ready.
    are_splits_ready = constant_op.constant(True)
    return (are_splits_ready, partition_ids, gains, split_infos)

  def reset(self, stamp_token, next_stamp_token):
    reset = self._stats_accumulator.flush(stamp_token, next_stamp_token)
    return reset