# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the handler for split nodes for categorical columns."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

_BIAS_FEATURE_ID = int(dtypes.int64.min)


class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
  """Creates equality split type for categorical features."""

  def __init__(self,
               sparse_int_column,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
               name=None):
    """Initializes the internal state for this split handler.

    Args:
      sparse_int_column: A `SparseTensor` column with int64 values associated
        with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
        for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
        be considered for splitting.
      feature_column_group_id: Feature column group index.
      gradient_shape: A `TensorShape` containing the shape of gradients.
      hessian_shape: A `TensorShape` containing the shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass
        problems.
      init_stamp_token: A tensor containing a scalar for the initial stamp of
        the stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
        SUM or MEAN reduction was used for the loss.
      weak_learner_type: Specifies the type of weak learner to use.
      name: An optional handler name.
73 """ 74 super(EqualitySplitHandler, self).__init__( 75 l1_regularization=l1_regularization, 76 l2_regularization=l2_regularization, 77 tree_complexity_regularization=tree_complexity_regularization, 78 min_node_weight=min_node_weight, 79 feature_column_group_id=feature_column_group_id, 80 gradient_shape=gradient_shape, 81 hessian_shape=hessian_shape, 82 multiclass_strategy=multiclass_strategy, 83 loss_uses_sum_reduction=loss_uses_sum_reduction, 84 name=name) 85 self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( 86 init_stamp_token, 87 gradient_shape, 88 hessian_shape, 89 name="StatsAccumulator/{}".format(self._name)) 90 self._sparse_int_column = sparse_int_column 91 self._weak_learner_type = weak_learner_type 92 93 def update_stats(self, stamp_token, example_partition_ids, gradients, 94 hessians, empty_gradients, empty_hessians, weights, 95 is_active, scheduled_reads): 96 """Updates the state for equality split handler. 97 98 Args: 99 stamp_token: An int32 scalar tensor containing the current stamp token. 100 example_partition_ids: A dense tensor, containing an int32 for each 101 example which is the partition id that the example ends up in. 102 gradients: A dense tensor of gradients. 103 hessians: A dense tensor of hessians. 104 empty_gradients: A dense empty tensor of the same shape (for dimensions > 105 0) as gradients. 106 empty_hessians: A dense empty tensor of the same shape (for dimensions > 107 0) as hessians. 108 weights: A dense float32 tensor with a weight for each example. 109 is_active: A boolean tensor that says if this handler is active or not. 110 One value for the current layer and one value for the next layer. 111 scheduled_reads: List of results from the scheduled reads. 112 Returns: 113 The op that updates the stats for this handler. 114 Raises: 115 ValueError: If example_columns is not a single sparse column. 116 117 """ 118 del scheduled_reads # Unused by the categorical split handler. 119 120 def not_active_inputs(): 121 return (constant_op.constant([], dtype=dtypes.int32), 122 constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), 123 empty_gradients, empty_hessians) 124 125 def active_inputs(): 126 """The normal flow when the handler is active.""" 127 # Remove the second column of example indices matrix since it is not 128 # useful. 129 example_indices, _ = array_ops.split( 130 self._sparse_int_column.indices, num_or_size_splits=2, axis=1) 131 example_indices = array_ops.squeeze(example_indices, [1]) 132 133 filtered_gradients = array_ops.gather(gradients, example_indices) 134 filtered_hessians = array_ops.gather(hessians, example_indices) 135 filtered_partition_ids = array_ops.gather(example_partition_ids, 136 example_indices) 137 unique_partitions, mapped_partitions = array_ops.unique( 138 example_partition_ids) 139 140 # Compute aggregate stats for each partition. 141 # The bias is computed on gradients and hessians (and not 142 # filtered_gradients) which have exactly one value per example, so we 143 # don't double count a gradient in multivalent columns. 144 # Since unsorted_segment_sum can be numerically unstable, use 64bit 145 # operation. 
      gradients64 = math_ops.cast(gradients, dtypes.float64)
      hessians64 = math_ops.cast(hessians, dtypes.float64)
      per_partition_gradients = math_ops.unsorted_segment_sum(
          gradients64, mapped_partitions, array_ops.size(unique_partitions))
      per_partition_hessians = math_ops.unsorted_segment_sum(
          hessians64, mapped_partitions, array_ops.size(unique_partitions))
      per_partition_gradients = math_ops.cast(per_partition_gradients,
                                              dtypes.float32)
      per_partition_hessians = math_ops.cast(per_partition_hessians,
                                             dtypes.float32)
      # Prepend a bias feature per partition that accumulates the stats for
      # all examples in that partition.
      # The bias is added even if no examples in this batch have values in
      # the current sparse column, because other example batches might have
      # values in these partitions, so the bias must stay updated.
      bias_feature_ids = array_ops.fill(
          array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
      bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
      partition_ids = array_ops.concat(
          [unique_partitions, filtered_partition_ids], 0)
      filtered_gradients = array_ops.concat(
          [per_partition_gradients, filtered_gradients], 0)
      filtered_hessians = array_ops.concat(
          [per_partition_hessians, filtered_hessians], 0)
      feature_ids = array_ops.concat(
          [bias_feature_ids, self._sparse_int_column.values], 0)
      # The dimension is always zero for sparse int features.
      dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64)
      feature_ids_and_dimensions = array_ops.stack(
          [feature_ids, dimension_ids], axis=1)
      return (partition_ids, feature_ids_and_dimensions, filtered_gradients,
              filtered_hessians)

    partition_ids, feature_ids, gradients_out, hessians_out = (
        control_flow_ops.cond(is_active[0], active_inputs, not_active_inputs))
    result = self._stats_accumulator.schedule_add(partition_ids, feature_ids,
                                                  gradients_out, hessians_out)
    return (control_flow_ops.no_op(), [result])

  def make_splits(self, stamp_token, next_stamp_token, class_id):
    """Creates the best split using the accumulated stats; flushes the state."""
    # Get the aggregated gradients and hessians per <partition_id, feature_id>
    # pair.
    num_minibatches, partition_ids, feature_ids, gradients, hessians = (
        self._stats_accumulator.flush(stamp_token, next_stamp_token))
    # For SUM loss reduction, we don't need to divide by the number of
    # minibatches.
    num_minibatches = control_flow_ops.cond(
        ops.convert_to_tensor(self._loss_uses_sum_reduction),
        lambda: math_ops.cast(1, dtypes.int64),
        lambda: num_minibatches)
    partition_ids, gains, split_infos = (
        split_handler_ops.build_categorical_equality_splits(
            num_minibatches=num_minibatches,
            partition_ids=partition_ids,
            feature_ids=feature_ids,
            gradients=gradients,
            hessians=hessians,
            class_id=class_id,
            feature_column_group_id=self._feature_column_group_id,
            l1_regularization=self._l1_regularization,
            l2_regularization=self._l2_regularization,
            tree_complexity_regularization=self._tree_complexity_regularization,
            min_node_weight=self._min_node_weight,
            bias_feature_id=_BIAS_FEATURE_ID,
            multiclass_strategy=self._multiclass_strategy,
            weak_learner_type=self._weak_learner_type))
    # There are no warm-up rounds needed in the equality column handler, so we
    # always return ready.
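    # (By contrast, quantile-based inequality handlers generally need a
    # warm-up round to compute bucket boundaries before they report ready.)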
    are_splits_ready = constant_op.constant(True)
    return (are_splits_ready, partition_ids, gains, split_infos)

  def reset(self, stamp_token, next_stamp_token):
    reset = self._stats_accumulator.flush(stamp_token, next_stamp_token)
    return reset
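
# A minimal usage sketch, kept as a comment: the tensor values, the use of
# tensor_shape.scalar() and TREE_PER_CLASS, and the surrounding TF 1.x
# graph-mode setup are illustrative assumptions, not requirements of this
# module.
#
#   from tensorflow.python.framework import sparse_tensor
#   from tensorflow.python.framework import tensor_shape
#
#   # A single int64 categorical column holding values for two examples.
#   column = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [1, 0]],
#       values=constant_op.constant([2, 7], dtype=dtypes.int64),
#       dense_shape=[2, 1])
#   handler = EqualitySplitHandler(
#       sparse_int_column=column,
#       l1_regularization=0.0,
#       l2_regularization=1.0,
#       tree_complexity_regularization=0.0,
#       min_node_weight=0.0,
#       feature_column_group_id=0,
#       gradient_shape=tensor_shape.scalar(),
#       hessian_shape=tensor_shape.scalar(),
#       multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
#       init_stamp_token=0)
#
#   # update_stats() is then called once per batch and make_splits() once per
#   # layer; are_splits_ready is always True for this handler.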