1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training functions for Gradient boosted decision trees.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23 24from tensorflow.contrib import learn 25from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler 26from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler 27from tensorflow.contrib.boosted_trees.proto import learner_pb2 28from tensorflow.contrib.boosted_trees.python.ops import batch_ops_utils 29from tensorflow.contrib.boosted_trees.python.ops import gen_model_ops 30from tensorflow.contrib.boosted_trees.python.ops import model_ops 31from tensorflow.contrib.boosted_trees.python.ops import prediction_ops 32from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops 33from tensorflow.contrib.boosted_trees.python.ops import training_ops 34from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib 35from tensorflow.contrib.layers.python.layers import feature_column_ops 36from tensorflow.python.feature_column import feature_column as fc_core 37from tensorflow.python.framework import constant_op 38from tensorflow.python.framework import dtypes 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import sparse_tensor 41from tensorflow.python.framework import tensor_shape 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import control_flow_ops 44from tensorflow.python.ops import gradients_impl 45from tensorflow.python.ops import math_ops 46from tensorflow.python.ops import stateless_random_ops as stateless 47from tensorflow.python.ops import variable_scope 48from tensorflow.python.ops import variables 49from tensorflow.python.ops.losses import losses 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.summary import summary 52from tensorflow.python.training import device_setter 53 54 55# Key names for prediction dict. 56ENSEMBLE_STAMP = "ensemble_stamp" 57PREDICTIONS = "predictions" 58PARTITION_IDS = "partition_ids" 59NUM_LAYERS_ATTEMPTED = "num_layers" 60NUM_TREES_ATTEMPTED = "num_trees" 61NUM_USED_HANDLERS = "num_used_handlers" 62USED_HANDLERS_MASK = "used_handlers_mask" 63LEAF_INDEX = "leaf_index" 64_FEATURE_NAME_TEMPLATE = "%s_%d" 65 66# Keys in Training state. 67GBDTTrainingState = collections.namedtuple("GBDTTrainingState", [ 68 "num_layer_examples", "num_layer_steps", "num_layers", "active_tree", 69 "active_layer", "continue_centering", "bias_stats_accumulator", 70 "steps_accumulator", "handlers" 71]) 72 73 74def _get_column_by_index(tensor, indices): 75 """Returns columns from a 2-D tensor by index.""" 76 shape = array_ops.shape(tensor) 77 p_flat = array_ops.reshape(tensor, [-1]) 78 i_flat = array_ops.reshape( 79 array_ops.reshape(math_ops.range(0, shape[0]) * shape[1], [-1, 1]) + 80 indices, [-1]) 81 return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) 82 83 84def _make_predictions_dict(stamp, 85 logits, 86 partition_ids, 87 ensemble_stats, 88 used_handlers, 89 leaf_index=None): 90 """Returns predictions for the given logits and n_classes. 91 92 Args: 93 stamp: The ensemble stamp. 94 logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that 95 contains predictions when no dropout was applied. 96 partition_ids: A rank 1 `Tensor` with shape [batch_size]. 97 ensemble_stats: A TreeEnsembleStatsOp result tuple. 98 used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a 99 boolean mask. 100 leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees]. that 101 contains leaf id for each example prediction. 102 103 Returns: 104 A dict of predictions. 105 """ 106 result = {} 107 result[ENSEMBLE_STAMP] = stamp 108 result[PREDICTIONS] = logits 109 result[PARTITION_IDS] = partition_ids 110 result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers 111 result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees 112 result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers 113 result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask 114 if leaf_index is not None: 115 result[LEAF_INDEX] = leaf_index 116 return result 117 118 119class _OpRoundRobinStrategy(object): 120 """Returns the next ps task index for placement via per-Op round-robin order. 121 122 This strategy works slightly better for the GBDT graph because of using 123 custom resources which vary significantly in compute cost. 124 """ 125 126 def __init__(self, ps_ops, num_tasks): 127 """Create a new `_RoundRobinStrategy`. 128 129 Args: 130 ps_ops: List of Op types to place on PS. 131 num_tasks: Number of ps tasks to cycle among. 132 """ 133 next_task = 0 134 self._next_task_per_op = {} 135 for op in ps_ops: 136 self._next_task_per_op[op] = next_task 137 next_task = (next_task + 1) % num_tasks if num_tasks else 0 138 self._num_tasks = num_tasks 139 140 def __call__(self, op): 141 """Choose a ps task index for the given `Operation`. 142 143 Args: 144 op: An `Operation` to be placed on ps. 145 146 Returns: 147 The next ps task index to use for the `Operation`. Returns the next 148 index, in the range `[offset, offset + num_tasks)`. 149 150 Raises: 151 ValueError: If attempting to place non-PS Op. 152 """ 153 if op.type not in self._next_task_per_op: 154 raise ValueError("Unknown op type '%s' for placement:" % op.type) 155 task = self._next_task_per_op[op.type] 156 self._next_task_per_op[op.type] = ((task + 1) % self._num_tasks 157 if self._num_tasks else 0) 158 return task 159 160 161def extract_features(features, feature_columns, use_core_columns): 162 """Extracts columns from a dictionary of features. 163 164 Args: 165 features: `dict` of `Tensor` objects. 166 feature_columns: A list of feature_columns. 167 168 Returns: 169 Seven values: 170 - A list of all feature column names. 171 - A list of dense floats. 172 - A list of sparse float feature indices. 173 - A list of sparse float feature values. 174 - A list of sparse float feature shapes. 175 - A list of sparse int feature indices. 176 - A list of sparse int feature values. 177 - A list of sparse int feature shapes. 178 Raises: 179 ValueError: if features is not valid. 180 """ 181 if not features: 182 raise ValueError("Features dictionary must be specified.") 183 184 # Make a shallow copy of features to ensure downstream usage 185 # is unaffected by modifications in the model function. 186 features = copy.copy(features) 187 if feature_columns: 188 scope = "gbdt" 189 with variable_scope.variable_scope(scope): 190 feature_columns = list(feature_columns) 191 transformed_features = collections.OrderedDict() 192 for fc in feature_columns: 193 # pylint: disable=protected-access 194 if use_core_columns: 195 # pylint: disable=protected-access 196 tensor = fc_core._transform_features(features, [fc])[fc] 197 transformed_features[fc.name] = tensor 198 elif isinstance(fc, feature_column_lib._EmbeddingColumn): 199 # pylint: enable=protected-access 200 transformed_features[fc.name] = fc_core.input_layer( 201 features, [fc], weight_collections=[scope]) 202 else: 203 result = feature_column_ops.transform_features(features, [fc]) 204 if len(result) > 1: 205 raise ValueError("Unexpected number of output features") 206 transformed_features[fc.name] = result[list(result.keys())[0]] 207 features = transformed_features 208 209 dense_float_names = [] 210 dense_floats = [] 211 sparse_float_names = [] 212 sparse_float_indices = [] 213 sparse_float_values = [] 214 sparse_float_shapes = [] 215 sparse_int_names = [] 216 sparse_int_indices = [] 217 sparse_int_values = [] 218 sparse_int_shapes = [] 219 for key in sorted(features.keys()): 220 tensor = features[key] 221 # TODO(nponomareva): consider iterating over feature columns instead. 222 if isinstance(tensor, tuple): 223 # Weighted categorical feature. 224 categorical_tensor = tensor[0] 225 weight_tensor = tensor[1] 226 227 shape = categorical_tensor.dense_shape 228 indices = array_ops.concat([ 229 array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]), 230 array_ops.expand_dims( 231 math_ops.cast(categorical_tensor.values, dtypes.int64), -1) 232 ], 1) 233 tensor = sparse_tensor.SparseTensor( 234 indices=indices, values=weight_tensor.values, dense_shape=shape) 235 236 if isinstance(tensor, sparse_tensor.SparseTensor): 237 if tensor.values.dtype == dtypes.float32: 238 sparse_float_names.append(key) 239 sparse_float_indices.append(tensor.indices) 240 sparse_float_values.append(tensor.values) 241 sparse_float_shapes.append(tensor.dense_shape) 242 elif tensor.values.dtype == dtypes.int64: 243 sparse_int_names.append(key) 244 sparse_int_indices.append(tensor.indices) 245 sparse_int_values.append(tensor.values) 246 sparse_int_shapes.append(tensor.dense_shape) 247 else: 248 raise ValueError("Unsupported sparse feature %s with dtype %s." % 249 (tensor.indices.name, tensor.dtype)) 250 else: 251 if tensor.dtype == dtypes.float32: 252 if len(tensor.shape) > 1 and tensor.shape[1] > 1: 253 unstacked = array_ops.unstack(tensor, axis=1) 254 for i in range(len(unstacked)): 255 dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i)) 256 dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1])) 257 else: 258 dense_float_names.append(key) 259 dense_floats.append(tensor) 260 else: 261 raise ValueError("Unsupported dense feature %s with dtype %s." % 262 (tensor.name, tensor.dtype)) 263 # Feature columns are logically organized into incrementing slots starting 264 # from dense floats, then sparse floats then sparse ints. 265 fc_names = (dense_float_names + sparse_float_names + sparse_int_names) 266 return (fc_names, dense_floats, sparse_float_indices, sparse_float_values, 267 sparse_float_shapes, sparse_int_indices, sparse_int_values, 268 sparse_int_shapes) 269 270 271def _dropout_params(mode, ensemble_stats): 272 """Returns parameters relevant for dropout. 273 274 Args: 275 mode: Train/Eval/Infer 276 ensemble_stats: A TreeEnsembleStatsOp result tuple. 277 278 Returns: 279 Whether to apply dropout and a dropout seed. 280 """ 281 if mode == learn.ModeKeys.TRAIN: 282 # Do dropout only during training. 283 apply_dropout = True 284 seed = ensemble_stats.attempted_trees 285 else: 286 seed = -1 287 apply_dropout = False 288 return apply_dropout, seed 289 290 291class GradientBoostedDecisionTreeModel(object): 292 """A GBDT model function.""" 293 294 def __init__(self, 295 is_chief, 296 num_ps_replicas, 297 ensemble_handle, 298 center_bias, 299 examples_per_layer, 300 learner_config, 301 features, 302 logits_dimension, 303 loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS, 304 feature_columns=None, 305 use_core_columns=False, 306 output_leaf_index=False, 307 output_leaf_index_modes=None, 308 num_quantiles=100): 309 """Construct a new GradientBoostedDecisionTreeModel function. 310 311 Args: 312 is_chief: Whether to build the chief graph. 313 num_ps_replicas: Number of parameter server replicas, can be 0. 314 ensemble_handle: A handle to the ensemble variable. 315 center_bias: Whether to center the bias before growing trees. 316 examples_per_layer: Number of examples to accumulate before growing a tree 317 layer. It can also be a function that computes the number of examples 318 based on the depth of the layer that's being built. 319 learner_config: A learner config. 320 features: `dict` of `Tensor` objects. 321 logits_dimension: An int, the dimension of logits. 322 loss_reduction: Either `SUM_OVER_NONZERO_WEIGHTS` (mean) or `SUM`. 323 feature_columns: A list of feature columns. 324 use_core_columns: A boolean specifying whether core feature columns are 325 used. 326 output_leaf_index: A boolean variable indicating whether to output leaf 327 index into predictions dictionary. 328 output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which 329 dictates when leaf indices will be outputted. By default, leaf indices 330 are only outputted in INFER mode. 331 num_quantiles: Number of quantiles to build for numeric feature values. 332 333 Raises: 334 ValueError: if inputs are not valid. 335 """ 336 if ensemble_handle is None: 337 raise ValueError("ensemble_handle must be specified.") 338 339 if learner_config is None: 340 raise ValueError("learner_config must be specified.") 341 342 if learner_config.num_classes < 2: 343 raise ValueError("Number of classes must be >=2") 344 345 self._logits_dimension = logits_dimension 346 self._is_chief = is_chief 347 self._num_ps_replicas = num_ps_replicas 348 self._ensemble_handle = ensemble_handle 349 self._center_bias = center_bias 350 self._examples_per_layer = examples_per_layer 351 352 # Check loss reduction value. 353 if (loss_reduction != losses.Reduction.SUM and 354 loss_reduction != losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): 355 raise ValueError( 356 "Invalid loss reduction is provided: %s." % loss_reduction) 357 self._loss_reduction = loss_reduction 358 359 # Fill in the defaults. 360 if (learner_config.multi_class_strategy == 361 learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED): 362 if logits_dimension == 1: 363 learner_config.multi_class_strategy = ( 364 learner_pb2.LearnerConfig.TREE_PER_CLASS) 365 else: 366 learner_config.multi_class_strategy = ( 367 learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) 368 369 if logits_dimension == 1 or learner_config.multi_class_strategy == ( 370 learner_pb2.LearnerConfig.TREE_PER_CLASS): 371 self._gradient_shape = tensor_shape.scalar() 372 self._hessian_shape = tensor_shape.scalar() 373 else: 374 if center_bias: 375 raise ValueError("Center bias should be False for multiclass.") 376 377 self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) 378 if (learner_config.multi_class_strategy == 379 learner_pb2.LearnerConfig.FULL_HESSIAN): 380 self._hessian_shape = tensor_shape.TensorShape( 381 ([logits_dimension, logits_dimension])) 382 else: 383 # Diagonal hessian strategy. 384 self._hessian_shape = tensor_shape.TensorShape(([logits_dimension])) 385 if (learner_config.growing_mode == 386 learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): 387 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER 388 389 if (learner_config.weak_learner_type == learner_pb2.LearnerConfig 390 .OBLIVIOUS_DECISION_TREE and learner_config.pruning_mode == learner_pb2 391 .LearnerConfig.PRUNING_MODE_UNSPECIFIED): 392 learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE 393 394 if (learner_config.pruning_mode == 395 learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED): 396 learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE 397 398 if (learner_config.weak_learner_type == learner_pb2.LearnerConfig 399 .OBLIVIOUS_DECISION_TREE and 400 learner_config.pruning_mode == learner_pb2.LearnerConfig.POST_PRUNE): 401 raise ValueError( 402 "Post pruning is not implmented for oblivious decision trees.") 403 404 if learner_config.constraints.max_tree_depth == 0: 405 # Use 6 as the default maximum depth. 406 learner_config.constraints.max_tree_depth = 6 407 408 tuner = learner_config.learning_rate_tuner.WhichOneof("tuner") 409 if not tuner: 410 learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 411 412 self._learner_config = learner_config 413 self._feature_columns = feature_columns 414 self._learner_config_serialized = learner_config.SerializeToString() 415 self._num_quantiles = num_quantiles 416 self._max_tree_depth = variables.VariableV1( 417 initial_value=self._learner_config.constraints.max_tree_depth) 418 self._attempted_trees = variables.VariableV1( 419 initial_value=array_ops.zeros([], dtypes.int64), 420 trainable=False, 421 name="attempted_trees") 422 self._finalized_trees = variables.VariableV1( 423 initial_value=array_ops.zeros([], dtypes.int64), 424 trainable=False, 425 name="finalized_trees") 426 if not features: 427 raise ValueError("Features dictionary must be specified.") 428 (fc_names, dense_floats, sparse_float_indices, sparse_float_values, 429 sparse_float_shapes, sparse_int_indices, 430 sparse_int_values, sparse_int_shapes) = extract_features( 431 features, self._feature_columns, use_core_columns) 432 if (learner_config.weak_learner_type == learner_pb2.LearnerConfig 433 .OBLIVIOUS_DECISION_TREE and sparse_float_indices): 434 raise ValueError("Oblivious trees don't handle sparse float features yet." 435 ) 436 437 logging.info("Active Feature Columns: " + str(fc_names)) 438 logging.info("Learner config: " + str(learner_config)) 439 self._fc_names = fc_names 440 self._dense_floats = dense_floats 441 self._sparse_float_indices = sparse_float_indices 442 self._sparse_float_values = sparse_float_values 443 self._sparse_float_shapes = sparse_float_shapes 444 self._sparse_int_indices = sparse_int_indices 445 self._sparse_int_values = sparse_int_values 446 self._sparse_int_shapes = sparse_int_shapes 447 self._reduce_dim = ( 448 self._learner_config.multi_class_strategy == 449 learner_pb2.LearnerConfig.TREE_PER_CLASS and 450 learner_config.num_classes == 2) 451 452 if output_leaf_index_modes is None: 453 output_leaf_index_modes = [learn.ModeKeys.INFER] 454 elif not all( 455 mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL, 456 learn.ModeKeys.INFER) for mode in output_leaf_index_modes): 457 raise ValueError("output_leaf_index_modes should only contain ModeKeys.") 458 459 self._output_leaf_index = output_leaf_index 460 self._output_leaf_index_modes = output_leaf_index_modes 461 462 def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): 463 """Runs prediction and returns a dictionary of the prediction results. 464 465 Args: 466 ensemble_handle: ensemble resource handle. 467 ensemble_stamp: stamp of ensemble resource. 468 mode: learn.ModeKeys.TRAIN or EVAL or INFER. 469 470 Returns: 471 a dictionary of prediction results - 472 ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS, 473 NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPTED. 474 """ 475 ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, 476 ensemble_stamp) 477 num_handlers = ( 478 len(self._dense_floats) + len(self._sparse_float_shapes) + len( 479 self._sparse_int_shapes)) 480 # Used during feature selection. 481 used_handlers = model_ops.tree_ensemble_used_handlers( 482 ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) 483 484 # We don't need dropout info - we can always restore it based on the 485 # seed. 486 apply_dropout, seed = _dropout_params(mode, ensemble_stats) 487 # Make sure ensemble stats run. This will check that the ensemble has 488 # the right stamp. 489 with ops.control_dependencies(ensemble_stats): 490 leaf_index = None 491 if self._output_leaf_index and mode in self._output_leaf_index_modes: 492 predictions, _, leaf_index = ( 493 prediction_ops).gradient_trees_prediction_verbose( 494 ensemble_handle, 495 seed, 496 self._dense_floats, 497 self._sparse_float_indices, 498 self._sparse_float_values, 499 self._sparse_float_shapes, 500 self._sparse_int_indices, 501 self._sparse_int_values, 502 self._sparse_int_shapes, 503 learner_config=self._learner_config_serialized, 504 apply_dropout=apply_dropout, 505 apply_averaging=mode != learn.ModeKeys.TRAIN, 506 use_locking=True, 507 center_bias=self._center_bias, 508 reduce_dim=self._reduce_dim) 509 else: 510 leaf_index = None 511 predictions, _ = prediction_ops.gradient_trees_prediction( 512 ensemble_handle, 513 seed, 514 self._dense_floats, 515 self._sparse_float_indices, 516 self._sparse_float_values, 517 self._sparse_float_shapes, 518 self._sparse_int_indices, 519 self._sparse_int_values, 520 self._sparse_int_shapes, 521 learner_config=self._learner_config_serialized, 522 apply_dropout=apply_dropout, 523 apply_averaging=mode != learn.ModeKeys.TRAIN, 524 use_locking=True, 525 center_bias=self._center_bias, 526 reduce_dim=self._reduce_dim) 527 partition_ids = prediction_ops.gradient_trees_partition_examples( 528 ensemble_handle, 529 self._dense_floats, 530 self._sparse_float_indices, 531 self._sparse_float_values, 532 self._sparse_float_shapes, 533 self._sparse_int_indices, 534 self._sparse_int_values, 535 self._sparse_int_shapes, 536 use_locking=True) 537 538 return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, 539 ensemble_stats, used_handlers, leaf_index) 540 541 def predict(self, mode): 542 """Returns predictions given the features and mode. 543 544 Args: 545 mode: Mode the graph is running in (train|predict|eval). 546 547 Returns: 548 A dict of predictions tensors. 549 550 Raises: 551 ValueError: if features is not valid. 552 """ 553 554 # Use the current ensemble to predict on the current batch of input. 555 # For faster prediction we check if the inputs are on the same device 556 # as the model. If not, we create a copy of the model on the worker. 557 input_deps = ( 558 self._dense_floats + self._sparse_float_indices + 559 self._sparse_int_indices) 560 if not input_deps: 561 raise ValueError("No input tensors for prediction.") 562 563 # Get most current model stamp. 564 ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle) 565 566 # Determine if ensemble is colocated with the inputs. 567 if self._ensemble_handle.device != input_deps[0].device: 568 # Create a local ensemble and get its local stamp. 569 with ops.name_scope("local_ensemble", "TreeEnsembleVariable"): 570 local_ensemble_handle = ( 571 gen_model_ops.decision_tree_ensemble_resource_handle_op( 572 self._ensemble_handle.op.name + "/local_ensemble")) 573 create_op = gen_model_ops.create_tree_ensemble_variable( 574 local_ensemble_handle, stamp_token=-1, tree_ensemble_config="") 575 with ops.control_dependencies([create_op]): 576 local_stamp = model_ops.tree_ensemble_stamp_token( 577 local_ensemble_handle) 578 579 # Determine whether the local ensemble is stale and update it if needed. 580 def _refresh_local_ensemble_fn(): 581 # Serialize the model from parameter server after reading the inputs. 582 with ops.control_dependencies([input_deps[0]]): 583 (ensemble_stamp, serialized_model) = ( 584 model_ops.tree_ensemble_serialize(self._ensemble_handle)) 585 586 # Update local ensemble with the serialized model from parameter server. 587 with ops.control_dependencies([create_op]): 588 return model_ops.tree_ensemble_deserialize( 589 local_ensemble_handle, 590 stamp_token=ensemble_stamp, 591 tree_ensemble_config=serialized_model), ensemble_stamp 592 593 refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond( 594 math_ops.not_equal(ensemble_stamp, 595 local_stamp), _refresh_local_ensemble_fn, 596 lambda: (control_flow_ops.no_op(), ensemble_stamp)) 597 598 # Once updated, use the local model for prediction. 599 with ops.control_dependencies([refresh_local_ensemble]): 600 return self._predict_and_return_dict(local_ensemble_handle, 601 ensemble_stamp, mode) 602 else: 603 # Use ensemble_handle directly, if colocated. 604 with ops.device(self._ensemble_handle.device): 605 return self._predict_and_return_dict(self._ensemble_handle, 606 ensemble_stamp, mode) 607 608 def _get_class_id(self, predictions_dict): 609 # Handle different multiclass strategies. 610 if (self._learner_config.multi_class_strategy == 611 learner_pb2.LearnerConfig.TREE_PER_CLASS and 612 self._logits_dimension != 1): 613 # Choose the class for which the tree is built (one vs rest). 614 return math_ops.cast( 615 predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension, 616 dtypes.int32) 617 return constant_op.constant(-1, dtype=dtypes.int32) 618 619 def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): 620 """Update the accumulators with stats from this batch. 621 622 Args: 623 loss: A scalar tensor representing average loss of examples. 624 predictions_dict: Dictionary of Rank 2 `Tensor` representing information 625 about predictions per example. 626 gradients: A tensor with the gradients with the respect to logits from 627 predictions_dict. If not provided, tensorflow will do 628 autodifferentiation. 629 hessians: A tensor with the hessians with the respect to logits from 630 predictions_dict. If not provided, tensorflow will do 631 autodifferentiation. 632 633 Returns: 634 Three values: 635 - An op that adds a new tree to the ensemble, and 636 - An op that increments the stamp but removes all the trees and resets 637 the handlers. This can be used to reset the state of the ensemble. 638 - A dict containing the training state. 639 640 Raises: 641 ValueError: if inputs are not valid. 642 """ 643 # Get the worker device from input dependencies. 644 input_deps = ( 645 self._dense_floats + self._sparse_float_indices + 646 self._sparse_int_indices) 647 worker_device = input_deps[0].device 648 649 # Get tensors relevant for training and form the loss. 650 predictions = predictions_dict[PREDICTIONS] 651 partition_ids = predictions_dict[PARTITION_IDS] 652 ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] 653 if gradients is None: 654 gradients = gradients_impl.gradients( 655 loss, 656 predictions, 657 name="Gradients", 658 colocate_gradients_with_ops=False, 659 gate_gradients=0, 660 aggregation_method=None)[0] 661 strategy = self._learner_config.multi_class_strategy 662 663 class_id = self._get_class_id(predictions_dict) 664 # Handle different multiclass strategies. 665 if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: 666 # We build one vs rest trees. 667 if self._logits_dimension == 1: 668 # We have only 1 score, gradients is of shape [batch, 1]. 669 if hessians is None: 670 hessians = gradients_impl.gradients( 671 gradients, 672 predictions, 673 name="Hessian", 674 colocate_gradients_with_ops=False, 675 gate_gradients=0, 676 aggregation_method=None)[0] 677 678 squeezed_gradients = array_ops.squeeze(gradients, axis=[1]) 679 squeezed_hessians = array_ops.squeeze(hessians, axis=[1]) 680 else: 681 if hessians is not None: 682 raise ValueError("Providing hessians is not yet supported here.") 683 hessian_list = self._diagonal_hessian(gradients, predictions) 684 # Assemble hessian list into a tensor. 685 hessians = array_ops.stack(hessian_list, axis=1) 686 # Use class id tensor to get the column with that index from gradients 687 # and hessians. 688 squeezed_gradients = array_ops.squeeze( 689 _get_column_by_index(gradients, class_id)) 690 squeezed_hessians = array_ops.squeeze( 691 _get_column_by_index(hessians, class_id)) 692 else: 693 if hessians is not None: 694 raise ValueError("Providing hessians is not yet supported here.") 695 # Other multiclass strategies. 696 if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: 697 hessian_list = self._full_hessian(gradients, predictions) 698 else: 699 # Diagonal hessian strategy. 700 hessian_list = self._diagonal_hessian(gradients, predictions) 701 702 squeezed_gradients = gradients 703 hessians = array_ops.stack(hessian_list, axis=1) 704 squeezed_hessians = hessians 705 706 # Get the weights for each example for quantiles calculation, 707 weights = self._get_weights(self._hessian_shape, squeezed_hessians) 708 709 # Create all handlers ensuring resources are evenly allocated across PS. 710 fc_name_idx = 0 711 handlers = [] 712 init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) 713 l1_regularization = constant_op.constant( 714 self._learner_config.regularization.l1, dtypes.float32) 715 l2_regularization = constant_op.constant( 716 self._learner_config.regularization.l2, dtypes.float32) 717 tree_complexity_regularization = constant_op.constant( 718 self._learner_config.regularization.tree_complexity, dtypes.float32) 719 min_node_weight = constant_op.constant( 720 self._learner_config.constraints.min_node_weight, dtypes.float32) 721 loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM 722 loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) 723 weak_learner_type = constant_op.constant( 724 self._learner_config.weak_learner_type) 725 num_quantiles = self._num_quantiles 726 epsilon = 1.0 / num_quantiles 727 strategy_tensor = constant_op.constant(strategy) 728 with ops.device(self._get_replica_device_setter(worker_device)): 729 # Create handlers for dense float columns 730 for dense_float_column_idx in range(len(self._dense_floats)): 731 fc_name = self._fc_names[fc_name_idx] 732 handlers.append( 733 ordinal_split_handler.DenseSplitHandler( 734 l1_regularization=l1_regularization, 735 l2_regularization=l2_regularization, 736 tree_complexity_regularization=tree_complexity_regularization, 737 min_node_weight=min_node_weight, 738 feature_column_group_id=constant_op.constant( 739 dense_float_column_idx), 740 epsilon=epsilon, 741 num_quantiles=num_quantiles, 742 dense_float_column=self._dense_floats[dense_float_column_idx], 743 name=fc_name, 744 gradient_shape=self._gradient_shape, 745 hessian_shape=self._hessian_shape, 746 multiclass_strategy=strategy_tensor, 747 init_stamp_token=init_stamp_token, 748 loss_uses_sum_reduction=loss_uses_sum_reduction, 749 weak_learner_type=weak_learner_type, 750 )) 751 fc_name_idx += 1 752 753 # Create handlers for sparse float columns. 754 for sparse_float_column_idx in range(len(self._sparse_float_indices)): 755 fc_name = self._fc_names[fc_name_idx] 756 handlers.append( 757 ordinal_split_handler.SparseSplitHandler( 758 l1_regularization=l1_regularization, 759 l2_regularization=l2_regularization, 760 tree_complexity_regularization=tree_complexity_regularization, 761 min_node_weight=min_node_weight, 762 feature_column_group_id=constant_op.constant( 763 sparse_float_column_idx), 764 epsilon=epsilon, 765 num_quantiles=num_quantiles, 766 sparse_float_column=sparse_tensor.SparseTensor( 767 self._sparse_float_indices[sparse_float_column_idx], 768 self._sparse_float_values[sparse_float_column_idx], 769 self._sparse_float_shapes[sparse_float_column_idx]), 770 name=fc_name, 771 gradient_shape=self._gradient_shape, 772 hessian_shape=self._hessian_shape, 773 multiclass_strategy=strategy_tensor, 774 init_stamp_token=init_stamp_token, 775 loss_uses_sum_reduction=loss_uses_sum_reduction)) 776 fc_name_idx += 1 777 778 # Create handlers for sparse int columns. 779 for sparse_int_column_idx in range(len(self._sparse_int_indices)): 780 fc_name = self._fc_names[fc_name_idx] 781 handlers.append( 782 categorical_split_handler.EqualitySplitHandler( 783 l1_regularization=l1_regularization, 784 l2_regularization=l2_regularization, 785 tree_complexity_regularization=tree_complexity_regularization, 786 min_node_weight=min_node_weight, 787 feature_column_group_id=constant_op.constant( 788 sparse_int_column_idx), 789 sparse_int_column=sparse_tensor.SparseTensor( 790 self._sparse_int_indices[sparse_int_column_idx], 791 self._sparse_int_values[sparse_int_column_idx], 792 self._sparse_int_shapes[sparse_int_column_idx]), 793 name=fc_name, 794 gradient_shape=self._gradient_shape, 795 hessian_shape=self._hessian_shape, 796 multiclass_strategy=strategy_tensor, 797 init_stamp_token=init_stamp_token, 798 loss_uses_sum_reduction=loss_uses_sum_reduction, 799 weak_learner_type=weak_learner_type)) 800 fc_name_idx += 1 801 802 # Create ensemble stats variables. 803 num_layer_examples = variables.VariableV1( 804 initial_value=array_ops.zeros([], dtypes.int64), 805 name="num_layer_examples", 806 trainable=False) 807 num_layer_steps = variables.VariableV1( 808 initial_value=array_ops.zeros([], dtypes.int64), 809 name="num_layer_steps", 810 trainable=False) 811 num_layers = variables.VariableV1( 812 initial_value=array_ops.zeros([], dtypes.int64), 813 name="num_layers", 814 trainable=False) 815 active_tree = variables.VariableV1( 816 initial_value=array_ops.zeros([], dtypes.int64), 817 name="active_tree", 818 trainable=False) 819 active_layer = variables.VariableV1( 820 initial_value=array_ops.zeros([], dtypes.int64), 821 name="active_layer", 822 trainable=False) 823 # Variable that becomes false once bias centering is done. 824 continue_centering = variables.VariableV1( 825 initial_value=self._center_bias, 826 name="continue_centering", 827 trainable=False) 828 # Create bias stats accumulator. 829 bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( 830 stamp_token=0, 831 gradient_shape=self._gradient_shape, 832 hessian_shape=self._hessian_shape, 833 name="BiasAccumulator") 834 # Create steps accumulator. 835 steps_accumulator = stats_accumulator_ops.StatsAccumulator( 836 stamp_token=0, 837 gradient_shape=tensor_shape.scalar(), 838 hessian_shape=tensor_shape.scalar(), 839 name="StepsAccumulator") 840 # Create ensemble stats summaries. 841 summary.scalar("layer_stats/num_examples", num_layer_examples) 842 summary.scalar("layer_stats/num_steps", num_layer_steps) 843 summary.scalar("ensemble_stats/active_tree", active_tree) 844 summary.scalar("ensemble_stats/active_layer", active_layer) 845 846 # Update bias stats. 847 stats_update_ops = [] 848 849 stats_update_ops.append( 850 control_flow_ops.cond( 851 continue_centering, 852 self._make_update_bias_stats_fn(ensemble_stamp, predictions, 853 gradients, bias_stats_accumulator, 854 hessians), control_flow_ops.no_op)) 855 856 # Update handler stats. 857 handler_reads = collections.OrderedDict() 858 for handler in handlers: 859 handler_reads[handler] = handler.scheduled_reads() 860 861 handler_results = batch_ops_utils.run_handler_scheduled_ops( 862 handler_reads, ensemble_stamp, worker_device) 863 per_handler_updates = collections.OrderedDict() 864 # Two values per handler. First one is if the handler is active for the 865 # current layer. The second one is if the handler is going to be active 866 # for the next layer. 867 subsampling_type = self._learner_config.WhichOneof("feature_fraction") 868 if subsampling_type == "feature_fraction_per_level": 869 seed = predictions_dict[NUM_LAYERS_ATTEMPTED] 870 active_handlers_current_layer = stateless.stateless_random_uniform( 871 shape=[len(handlers)], seed=[seed, 1]) 872 active_handlers_next_layer = stateless.stateless_random_uniform( 873 shape=[len(handlers)], seed=[seed + 1, 1]) 874 active_handlers = array_ops.stack( 875 [active_handlers_current_layer, active_handlers_next_layer], axis=1) 876 active_handlers = ( 877 active_handlers < self._learner_config.feature_fraction_per_level) 878 elif subsampling_type == "feature_fraction_per_tree": 879 seed = predictions_dict[NUM_TREES_ATTEMPTED] 880 active_handlers_current_layer = stateless.stateless_random_uniform( 881 shape=[len(handlers)], seed=[seed, 2]) 882 active_handlers_current_layer = ( 883 active_handlers_current_layer < 884 self._learner_config.feature_fraction_per_tree) 885 active_handlers = array_ops.stack( 886 [ 887 active_handlers_current_layer, 888 array_ops.ones([len(handlers)], dtype=dtypes.bool) 889 ], 890 axis=1) 891 else: 892 active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) 893 894 if self._learner_config.constraints.max_number_of_unique_feature_columns: 895 target = ( 896 self._learner_config.constraints.max_number_of_unique_feature_columns) 897 898 def _feature_selection_active_handlers(): 899 # The active list for current and the next iteration. 900 used_handlers = array_ops.reshape(predictions_dict[USED_HANDLERS_MASK], 901 [-1, 1]) 902 used_handlers = array_ops.concat([used_handlers, used_handlers], axis=1) 903 return math_ops.logical_and(used_handlers, active_handlers) 904 905 active_handlers = ( 906 control_flow_ops.cond(predictions_dict[NUM_USED_HANDLERS] >= target, 907 _feature_selection_active_handlers, 908 lambda: active_handlers)) 909 910 # Prepare empty gradients and hessians when handlers are not ready. 911 empty_hess_shape = [1] + self._hessian_shape.as_list() 912 empty_grad_shape = [1] + self._gradient_shape.as_list() 913 914 empty_gradients = constant_op.constant_v1( 915 [], dtype=dtypes.float32, shape=empty_grad_shape) 916 empty_hessians = constant_op.constant_v1( 917 [], dtype=dtypes.float32, shape=empty_hess_shape) 918 919 active_handlers = array_ops.unstack(active_handlers, axis=0) 920 for handler_idx in range(len(handlers)): 921 handler = handlers[handler_idx] 922 is_active = active_handlers[handler_idx] 923 updates, scheduled_updates = handler.update_stats( 924 ensemble_stamp, partition_ids, squeezed_gradients, squeezed_hessians, 925 empty_gradients, empty_hessians, weights, is_active, 926 handler_results[handler]) 927 stats_update_ops.append(updates) 928 per_handler_updates[handler] = scheduled_updates 929 930 update_results = batch_ops_utils.run_handler_scheduled_ops( 931 per_handler_updates, ensemble_stamp, worker_device) 932 for update in update_results.values(): 933 stats_update_ops += update 934 935 training_state = GBDTTrainingState( 936 num_layer_examples=num_layer_examples, 937 num_layer_steps=num_layer_steps, 938 num_layers=num_layers, 939 active_tree=active_tree, 940 active_layer=active_layer, 941 continue_centering=continue_centering, 942 bias_stats_accumulator=bias_stats_accumulator, 943 steps_accumulator=steps_accumulator, 944 handlers=handlers) 945 946 reset_op = control_flow_ops.no_op() 947 if self._is_chief: 948 # Advance the ensemble stamp to throw away staggered workers. 949 stamp_token, _ = model_ops.tree_ensemble_serialize(self._ensemble_handle) 950 next_stamp_token = stamp_token + 1 951 952 reset_ops = [] 953 for handler in handlers: 954 reset_ops.append(handler.reset(stamp_token, next_stamp_token)) 955 if self._center_bias: 956 reset_ops.append( 957 bias_stats_accumulator.flush(stamp_token, next_stamp_token)) 958 reset_ops.append(steps_accumulator.flush(stamp_token, next_stamp_token)) 959 reset_ops.append(self._finalized_trees.assign(0).op) 960 reset_ops.append(self._attempted_trees.assign(0).op) 961 reset_ops.append( 962 model_ops.tree_ensemble_deserialize( 963 self._ensemble_handle, 964 stamp_token=next_stamp_token, 965 tree_ensemble_config="", 966 name="reset_gbdt")) 967 968 reset_op = control_flow_ops.group([reset_ops]) 969 970 return stats_update_ops, reset_op, training_state 971 972 def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict, 973 training_state): 974 """Increments number of visited examples and grows the ensemble. 975 976 If the number of visited examples reaches the target examples_per_layer, 977 ensemble is updated. 978 979 Args: 980 predictions_dict: Dictionary of Rank 2 `Tensor` representing information 981 about predictions per example. 982 training_state: `dict` returned by update_stats. 983 984 Returns: 985 An op that updates the counters and potientially grows the ensemble. 986 """ 987 batch_size = math_ops.cast( 988 array_ops.shape(predictions_dict[PREDICTIONS])[0], dtypes.float32) 989 ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] 990 # Accumulate a step after updating stats. 991 992 steps_accumulator = training_state.steps_accumulator 993 num_layer_examples = training_state.num_layer_examples 994 num_layer_steps = training_state.num_layer_steps 995 active_layer = training_state.active_layer 996 add_step_op = steps_accumulator.add( 997 ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) 998 999 # After adding the step, decide if further processing is needed. 1000 ensemble_update_ops = [add_step_op] 1001 class_id = self._get_class_id(predictions_dict) 1002 1003 with ops.control_dependencies([add_step_op]): 1004 if self._is_chief: 1005 dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED] 1006 1007 # Get accumulated steps and examples for the current layer. 1008 _, _, _, _, acc_examples, acc_steps = ( 1009 steps_accumulator.saveable.serialize()) 1010 acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) 1011 acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) 1012 ensemble_update_ops.append( 1013 num_layer_examples.assign(acc_examples)) 1014 ensemble_update_ops.append(num_layer_steps.assign(acc_steps)) 1015 # Determine whether we need to update tree ensemble. 1016 examples_per_layer = self._examples_per_layer 1017 if callable(examples_per_layer): 1018 examples_per_layer = examples_per_layer(active_layer) 1019 ensemble_update_ops.append( 1020 control_flow_ops.cond( 1021 acc_examples >= examples_per_layer, 1022 self.make_update_ensemble_fn(ensemble_stamp, training_state, 1023 dropout_seed, class_id), 1024 control_flow_ops.no_op)) 1025 1026 # Note, the loss is calculated from the prediction considering dropouts, so 1027 # that the value might look staggering over steps when the dropout ratio is 1028 # high. eval_loss might be referred instead in the aspect of convergence. 1029 return control_flow_ops.group(*ensemble_update_ops) 1030 1031 def make_update_ensemble_fn(self, ensemble_stamp, training_state, 1032 dropout_seed, class_id): 1033 """A method to create the function which updates the tree ensemble.""" 1034 # Determine learning rate. 1035 learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( 1036 "tuner") 1037 if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": 1038 tuner = getattr(self._learner_config.learning_rate_tuner, 1039 learning_rate_tuner) 1040 learning_rate = tuner.learning_rate 1041 else: 1042 # TODO(nponomareva, soroush) do the line search. 1043 raise ValueError("Line search learning rate is not yet supported.") 1044 1045 def _update_ensemble(): 1046 """A method to update the tree ensemble.""" 1047 # Get next stamp token. 1048 next_ensemble_stamp = ensemble_stamp + 1 1049 # Finalize bias stats. 1050 _, _, _, bias_grads, bias_hess = ( 1051 training_state.bias_stats_accumulator.flush(ensemble_stamp, 1052 next_ensemble_stamp)) 1053 1054 # Finalize handler splits. 1055 are_splits_ready_list = [] 1056 partition_ids_list = [] 1057 gains_list = [] 1058 split_info_list = [] 1059 1060 for handler in training_state.handlers: 1061 (are_splits_ready, 1062 partition_ids, gains, split_info) = handler.make_splits( 1063 ensemble_stamp, next_ensemble_stamp, class_id) 1064 are_splits_ready_list.append(are_splits_ready) 1065 partition_ids_list.append(partition_ids) 1066 gains_list.append(gains) 1067 split_info_list.append(split_info) 1068 # Stack all the inputs to one tensor per type. 1069 # This is a workaround for the slowness of graph building in tf.cond. 1070 # See (b/36554864). 1071 split_sizes = array_ops.reshape( 1072 array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) 1073 partition_ids = array_ops.concat(partition_ids_list, axis=0) 1074 gains = array_ops.concat(gains_list, axis=0) 1075 split_infos = array_ops.concat(split_info_list, axis=0) 1076 1077 # Determine if all splits are ready. 1078 are_all_splits_ready = math_ops.reduce_all( 1079 array_ops.stack( 1080 are_splits_ready_list, axis=0, name="stack_handler_readiness")) 1081 1082 # Define bias centering update operation. 1083 def _center_bias_fn(): 1084 # Center tree ensemble bias. 1085 delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, 1086 array_ops.zeros_like(bias_grads)) 1087 center_bias = training_ops.center_tree_ensemble_bias( 1088 tree_ensemble_handle=self._ensemble_handle, 1089 stamp_token=ensemble_stamp, 1090 next_stamp_token=next_ensemble_stamp, 1091 delta_updates=delta_updates, 1092 learner_config=self._learner_config_serialized) 1093 return training_state.continue_centering.assign(center_bias) 1094 1095 # Define ensemble growing operations. 1096 def _grow_ensemble_ready_fn(): 1097 # Grow the ensemble given the current candidates. 1098 sizes = array_ops.unstack(split_sizes) 1099 partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) 1100 # When using the oblivious decision tree as weak learner, it produces 1101 # one gain and one split per handler and not number of partitions. 1102 if self._learner_config.weak_learner_type == ( 1103 learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE): 1104 sizes = len(training_state.handlers) 1105 1106 gains_list = list(array_ops.split(gains, sizes, axis=0)) 1107 split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) 1108 return training_ops.grow_tree_ensemble( 1109 tree_ensemble_handle=self._ensemble_handle, 1110 stamp_token=ensemble_stamp, 1111 next_stamp_token=next_ensemble_stamp, 1112 learning_rate=learning_rate, 1113 partition_ids=partition_ids_list, 1114 gains=gains_list, 1115 splits=split_info_list, 1116 learner_config=self._learner_config_serialized, 1117 dropout_seed=dropout_seed, 1118 center_bias=self._center_bias, 1119 max_tree_depth=self._max_tree_depth, 1120 weak_learner_type=self._learner_config.weak_learner_type) 1121 1122 def _grow_ensemble_not_ready_fn(): 1123 # Don't grow the ensemble, just update the stamp. 1124 return training_ops.grow_tree_ensemble( 1125 tree_ensemble_handle=self._ensemble_handle, 1126 stamp_token=ensemble_stamp, 1127 next_stamp_token=next_ensemble_stamp, 1128 learning_rate=0, 1129 partition_ids=[], 1130 gains=[], 1131 splits=[], 1132 learner_config=self._learner_config_serialized, 1133 dropout_seed=dropout_seed, 1134 center_bias=self._center_bias, 1135 max_tree_depth=self._max_tree_depth, 1136 weak_learner_type=self._learner_config.weak_learner_type) 1137 1138 def _grow_ensemble_fn(): 1139 # Conditionally grow an ensemble depending on whether the splits 1140 # from all the handlers are ready. 1141 return control_flow_ops.cond(are_all_splits_ready, 1142 _grow_ensemble_ready_fn, 1143 _grow_ensemble_not_ready_fn) 1144 1145 # Update ensemble. 1146 update_ops = [are_all_splits_ready] 1147 if self._center_bias: 1148 update_model = control_flow_ops.cond(training_state.continue_centering, 1149 _center_bias_fn, _grow_ensemble_fn) 1150 else: 1151 update_model = _grow_ensemble_fn() 1152 update_ops.append(update_model) 1153 1154 # Update ensemble stats. 1155 with ops.control_dependencies([update_model]): 1156 stats = training_ops.tree_ensemble_stats( 1157 self._ensemble_handle, stamp_token=next_ensemble_stamp) 1158 update_ops.append(self._finalized_trees.assign(stats.num_trees)) 1159 update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) 1160 update_ops.append(training_state.num_layers.assign(stats.num_layers)) 1161 update_ops.append(training_state.active_tree.assign(stats.active_tree)) 1162 update_ops.append( 1163 training_state.active_layer.assign(stats.active_layer)) 1164 1165 # Flush step stats. 1166 update_ops.extend( 1167 training_state.steps_accumulator.flush(ensemble_stamp, 1168 next_ensemble_stamp)) 1169 return control_flow_ops.group(*update_ops, name="update_ensemble") 1170 1171 return _update_ensemble 1172 1173 def get_number_of_trees_tensor(self): 1174 return self._finalized_trees, self._attempted_trees 1175 1176 def get_max_tree_depth(self): 1177 return self._max_tree_depth 1178 1179 def train(self, loss, predictions_dict, labels, gradients=None, 1180 hessians=None): 1181 """Updates the accumalator stats and grows the ensemble. 1182 1183 Args: 1184 loss: A scalar tensor representing average loss of examples. 1185 predictions_dict: Dictionary of Rank 2 `Tensor` representing information 1186 about predictions per example. 1187 labels: Rank 2 `Tensor` representing labels per example. Has no effect 1188 on the training and is only kept for backward compatibility. 1189 gradients: A tensor with the gradients with the respect to logits from 1190 predictions_dict. If not provided, tensorflow will do 1191 autodifferentiation. 1192 hessians: A tensor with the hessians with the respect to logits from 1193 predictions_dict. If not provided, tensorflow will do 1194 autodifferentiation. 1195 1196 Returns: 1197 An op that adds a new tree to the ensemble. 1198 1199 Raises: 1200 ValueError: if inputs are not valid. 1201 """ 1202 del labels # unused; kept for backward compatibility. 1203 update_op, _, training_state = self.update_stats(loss, predictions_dict, 1204 gradients, hessians) 1205 with ops.control_dependencies(update_op): 1206 return self.increment_step_counter_and_maybe_update_ensemble( 1207 predictions_dict, training_state) 1208 1209 def _get_weights(self, hessian_shape, hessians): 1210 """Derives weights to be used based on hessians and multiclass strategy.""" 1211 if hessian_shape == tensor_shape.scalar(): 1212 # This is tree per class. 1213 weights = hessians 1214 elif len(hessian_shape.dims) == 1: 1215 # This is diagonal hessian. 1216 weights = math_ops.reduce_sum(hessians, axis=1) 1217 else: 1218 # This is full hessian. 1219 weights = math_ops.trace(hessians) 1220 return weights 1221 1222 def _full_hessian(self, grads, predictions): 1223 """Prepares hessians for full-hessian multiclass strategy.""" 1224 # Because of 1225 # https://github.com/tensorflow/tensorflow/issues/675, we can't just 1226 # compute the full hessian with a single call to gradients, but instead 1227 # must compute it row-by-row. 1228 gradients_list = array_ops.unstack( 1229 grads, num=self._logits_dimension, axis=1) 1230 hessian_rows = [] 1231 1232 for row in range(self._logits_dimension): 1233 # If current row is i, K is number of classes,each row returns a tensor of 1234 # size batch_size x K representing for each example dx_i dx_1, dx_i dx_2 1235 # etc dx_i dx_K 1236 hessian_row = gradients_impl.gradients( 1237 gradients_list[row], 1238 predictions, 1239 name="Hessian_%d" % row, 1240 colocate_gradients_with_ops=False, 1241 gate_gradients=0, 1242 aggregation_method=None) 1243 1244 # hessian_row is of dimension 1, batch_size, K, => trim first dimension 1245 # to get batch_size x K 1246 hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0]) 1247 hessian_rows.append(hessian_row) 1248 return hessian_rows 1249 1250 def _diagonal_hessian(self, grads, predictions): 1251 """Prepares hessians for diagonal-hessian multiclass mode.""" 1252 diag_hessian_list = [] 1253 1254 gradients_list = array_ops.unstack( 1255 grads, num=self._logits_dimension, axis=1) 1256 1257 for row, row_grads in enumerate(gradients_list): 1258 # If current row is i, K is number of classes,each row returns a tensor of 1259 # size batch_size x K representing for each example dx_i dx_1, dx_1 dx_2 1260 # etc dx_i dx_K 1261 hessian_row = gradients_impl.gradients( 1262 row_grads, 1263 predictions, 1264 name="Hessian_%d" % row, 1265 colocate_gradients_with_ops=False, 1266 gate_gradients=0, 1267 aggregation_method=None) 1268 1269 # hessian_row is of dimension 1, batch_size, K, => trim first dimension 1270 # to get batch_size x K 1271 hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0]) 1272 1273 # Get dx_i^2 for the whole batch. 1274 elem = array_ops.transpose(hessian_row)[row] 1275 diag_hessian_list.append(elem) 1276 1277 return diag_hessian_list 1278 1279 def _get_replica_device_setter(self, worker_device): 1280 """Creates a replica device setter.""" 1281 ps_tasks = self._num_ps_replicas 1282 ps_ops = list(device_setter.STANDARD_PS_OPS) 1283 ps_ops.extend([ 1284 "DecisionTreeEnsembleResourceHandleOp", 1285 "StatsAccumulatorScalarResourceHandleOp", 1286 "StatsAccumulatorTensorResourceHandleOp", 1287 ]) 1288 ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) 1289 return device_setter.replica_device_setter( 1290 worker_device=worker_device, 1291 ps_tasks=ps_tasks, 1292 merge_devices=True, 1293 ps_ops=ps_ops, 1294 ps_strategy=ps_strategy) 1295 1296 def _make_update_bias_stats_fn(self, 1297 ensemble_stamp, 1298 predictions, 1299 gradients, 1300 bias_stats_accumulator, 1301 hessians=None): 1302 """A method to create the function which updates the bias stats.""" 1303 1304 def _update_bias_stats(): 1305 """A method to update the bias stats.""" 1306 # Get reduced gradients and hessians. 1307 grads_sum = math_ops.reduce_sum(gradients, 0) 1308 if hessians is not None: 1309 hess = hessians 1310 else: 1311 hess = gradients_impl.gradients( 1312 grads_sum, 1313 predictions, 1314 name="Hessians", 1315 colocate_gradients_with_ops=False, 1316 gate_gradients=0, 1317 aggregation_method=None)[0] 1318 hess_sum = math_ops.reduce_sum(hess, 0) 1319 1320 # Accumulate gradients and hessians. 1321 partition_ids = math_ops.range(self._logits_dimension) 1322 feature_ids = array_ops.zeros( 1323 [self._logits_dimension, 2], dtype=dtypes.int64) 1324 1325 add_stats_op = bias_stats_accumulator.add( 1326 ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) 1327 return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") 1328 1329 return _update_bias_stats 1330