1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training functions for Gradient boosted decision trees."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23
24from tensorflow.contrib import learn
25from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler
26from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler
27from tensorflow.contrib.boosted_trees.proto import learner_pb2
28from tensorflow.contrib.boosted_trees.python.ops import batch_ops_utils
29from tensorflow.contrib.boosted_trees.python.ops import gen_model_ops
30from tensorflow.contrib.boosted_trees.python.ops import model_ops
31from tensorflow.contrib.boosted_trees.python.ops import prediction_ops
32from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
33from tensorflow.contrib.boosted_trees.python.ops import training_ops
34from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
35from tensorflow.contrib.layers.python.layers import feature_column_ops
36from tensorflow.python.feature_column import feature_column as fc_core
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import sparse_tensor
41from tensorflow.python.framework import tensor_shape
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import gradients_impl
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import stateless_random_ops as stateless
47from tensorflow.python.ops import variable_scope
48from tensorflow.python.ops import variables
49from tensorflow.python.ops.losses import losses
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.summary import summary
52from tensorflow.python.training import device_setter
53
54
55# Key names for prediction dict.
56ENSEMBLE_STAMP = "ensemble_stamp"
57PREDICTIONS = "predictions"
58PARTITION_IDS = "partition_ids"
59NUM_LAYERS_ATTEMPTED = "num_layers"
60NUM_TREES_ATTEMPTED = "num_trees"
61NUM_USED_HANDLERS = "num_used_handlers"
62USED_HANDLERS_MASK = "used_handlers_mask"
63LEAF_INDEX = "leaf_index"
64_FEATURE_NAME_TEMPLATE = "%s_%d"
65
66# Keys in Training state.
67GBDTTrainingState = collections.namedtuple("GBDTTrainingState", [
68    "num_layer_examples", "num_layer_steps", "num_layers", "active_tree",
69    "active_layer", "continue_centering", "bias_stats_accumulator",
70    "steps_accumulator", "handlers"
71])
72
73
74def _get_column_by_index(tensor, indices):
75  """Returns columns from a 2-D tensor by index."""
76  shape = array_ops.shape(tensor)
77  p_flat = array_ops.reshape(tensor, [-1])
78  i_flat = array_ops.reshape(
79      array_ops.reshape(math_ops.range(0, shape[0]) * shape[1], [-1, 1]) +
80      indices, [-1])
81  return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1])
82
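
# A minimal sketch (not used anywhere in this module) of the row-offset trick
# that _get_column_by_index relies on, written with NumPy for clarity. For
# tensor=[[1., 2., 3.], [4., 5., 6.]] and indices=1, the flat offsets are
# [1, 4], so the result is [[2.], [5.]].
def _get_column_by_index_numpy_sketch(tensor, indices):
  import numpy as np  # Local import; this helper is illustrative only.
  rows, cols = tensor.shape
  flat = tensor.reshape(-1)
  # Per-row offsets into the flattened tensor, shifted by the column index.
  flat_indices = (np.arange(rows) * cols).reshape(-1, 1) + indices
  return flat[flat_indices.reshape(-1)].reshape(rows, -1)
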
83
84def _make_predictions_dict(stamp,
85                           logits,
86                           partition_ids,
87                           ensemble_stats,
88                           used_handlers,
89                           leaf_index=None):
  """Returns a dict of prediction results for the given logits.

  Args:
    stamp: The ensemble stamp.
    logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1] that
      contains predictions when no dropout was applied.
    partition_ids: A rank 1 `Tensor` with shape [batch_size].
    ensemble_stats: A TreeEnsembleStatsOp result tuple.
    used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a
      boolean mask.
    leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees] that
      contains the leaf id for each example prediction.

  Returns:
    A dict of predictions.
  """
106  result = {}
107  result[ENSEMBLE_STAMP] = stamp
108  result[PREDICTIONS] = logits
109  result[PARTITION_IDS] = partition_ids
110  result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers
111  result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees
112  result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers
113  result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask
114  if leaf_index is not None:
115    result[LEAF_INDEX] = leaf_index
116  return result
117
118
119class _OpRoundRobinStrategy(object):
120  """Returns the next ps task index for placement via per-Op round-robin order.
121
122  This strategy works slightly better for the GBDT graph because of using
123  custom resources which vary significantly in compute cost.
124  """
125
126  def __init__(self, ps_ops, num_tasks):
    """Create a new `_OpRoundRobinStrategy`.
128
129    Args:
130      ps_ops: List of Op types to place on PS.
131      num_tasks: Number of ps tasks to cycle among.
132    """
133    next_task = 0
134    self._next_task_per_op = {}
135    for op in ps_ops:
136      self._next_task_per_op[op] = next_task
137      next_task = (next_task + 1) % num_tasks if num_tasks else 0
138    self._num_tasks = num_tasks
139
140  def __call__(self, op):
141    """Choose a ps task index for the given `Operation`.
142
143    Args:
144      op: An `Operation` to be placed on ps.
145
146    Returns:
      The next ps task index to use for the `Operation`, in the range
      `[0, num_tasks)`.
149
150    Raises:
151      ValueError: If attempting to place non-PS Op.
152    """
153    if op.type not in self._next_task_per_op:
      raise ValueError("Unknown op type '%s' for placement." % op.type)
155    task = self._next_task_per_op[op.type]
156    self._next_task_per_op[op.type] = ((task + 1) % self._num_tasks
157                                       if self._num_tasks else 0)
158    return task
159
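
# A minimal sketch (illustrative only, not used by this module) of how the
# per-Op round robin assigns ps tasks. A namedtuple stands in for a
# tf.Operation since the strategy only reads `op.type`.
def _op_round_robin_sketch():
  fake_op = collections.namedtuple("FakeOp", ["type"])
  strategy = _OpRoundRobinStrategy(["VarHandleOp", "Variable"], num_tasks=2)
  # The same op type cycles through tasks 0, 1, 0, ...; the second op type
  # starts one task later, so the two types interleave across the ps tasks.
  var_handle_tasks = [strategy(fake_op("VarHandleOp")) for _ in range(3)]
  variable_task = strategy(fake_op("Variable"))
  return var_handle_tasks, variable_task  # ([0, 1, 0], 1)
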
160
161def extract_features(features, feature_columns, use_core_columns):
162  """Extracts columns from a dictionary of features.
163
  Args:
    features: `dict` of `Tensor` objects.
    feature_columns: A list of feature_columns.
    use_core_columns: A boolean specifying whether core feature columns are
      used.

  Returns:
    Eight values:
170      - A list of all feature column names.
171      - A list of dense floats.
172      - A list of sparse float feature indices.
173      - A list of sparse float feature values.
174      - A list of sparse float feature shapes.
175      - A list of sparse int feature indices.
176      - A list of sparse int feature values.
177      - A list of sparse int feature shapes.
178  Raises:
179    ValueError: if features is not valid.
180  """
181  if not features:
182    raise ValueError("Features dictionary must be specified.")
183
184  # Make a shallow copy of features to ensure downstream usage
185  # is unaffected by modifications in the model function.
186  features = copy.copy(features)
187  if feature_columns:
188    scope = "gbdt"
189    with variable_scope.variable_scope(scope):
190      feature_columns = list(feature_columns)
191      transformed_features = collections.OrderedDict()
192      for fc in feature_columns:
193        # pylint: disable=protected-access
194        if use_core_columns:
195          # pylint: disable=protected-access
196          tensor = fc_core._transform_features(features, [fc])[fc]
197          transformed_features[fc.name] = tensor
198        elif isinstance(fc, feature_column_lib._EmbeddingColumn):
199          # pylint: enable=protected-access
200          transformed_features[fc.name] = fc_core.input_layer(
201              features, [fc], weight_collections=[scope])
202        else:
203          result = feature_column_ops.transform_features(features, [fc])
204          if len(result) > 1:
205            raise ValueError("Unexpected number of output features")
206          transformed_features[fc.name] = result[list(result.keys())[0]]
207    features = transformed_features
208
209  dense_float_names = []
210  dense_floats = []
211  sparse_float_names = []
212  sparse_float_indices = []
213  sparse_float_values = []
214  sparse_float_shapes = []
215  sparse_int_names = []
216  sparse_int_indices = []
217  sparse_int_values = []
218  sparse_int_shapes = []
219  for key in sorted(features.keys()):
220    tensor = features[key]
221    # TODO(nponomareva): consider iterating over feature columns instead.
222    if isinstance(tensor, tuple):
223      # Weighted categorical feature.
224      categorical_tensor = tensor[0]
225      weight_tensor = tensor[1]
226
227      shape = categorical_tensor.dense_shape
228      indices = array_ops.concat([
229          array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]),
230          array_ops.expand_dims(
231              math_ops.cast(categorical_tensor.values, dtypes.int64), -1)
232      ], 1)
233      tensor = sparse_tensor.SparseTensor(
234          indices=indices, values=weight_tensor.values, dense_shape=shape)
235
236    if isinstance(tensor, sparse_tensor.SparseTensor):
237      if tensor.values.dtype == dtypes.float32:
238        sparse_float_names.append(key)
239        sparse_float_indices.append(tensor.indices)
240        sparse_float_values.append(tensor.values)
241        sparse_float_shapes.append(tensor.dense_shape)
242      elif tensor.values.dtype == dtypes.int64:
243        sparse_int_names.append(key)
244        sparse_int_indices.append(tensor.indices)
245        sparse_int_values.append(tensor.values)
246        sparse_int_shapes.append(tensor.dense_shape)
247      else:
248        raise ValueError("Unsupported sparse feature %s with dtype %s." %
249                         (tensor.indices.name, tensor.dtype))
250    else:
251      if tensor.dtype == dtypes.float32:
252        if len(tensor.shape) > 1 and tensor.shape[1] > 1:
253          unstacked = array_ops.unstack(tensor, axis=1)
254          for i in range(len(unstacked)):
255            dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i))
256            dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1]))
257        else:
258          dense_float_names.append(key)
259          dense_floats.append(tensor)
260      else:
261        raise ValueError("Unsupported dense feature %s with dtype %s." %
262                         (tensor.name, tensor.dtype))
263  # Feature columns are logically organized into incrementing slots starting
264  # from dense floats, then sparse floats then sparse ints.
265  fc_names = (dense_float_names + sparse_float_names + sparse_int_names)
266  return (fc_names, dense_floats, sparse_float_indices, sparse_float_values,
267          sparse_float_shapes, sparse_int_indices, sparse_int_values,
268          sparse_int_shapes)
269
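
# A minimal sketch (illustrative only) of how extract_features buckets raw
# inputs: a float32 dense tensor lands in the dense-float lists, while an
# int64 SparseTensor lands in the sparse-int lists. The feature names here
# are made up for the example.
def _extract_features_sketch():
  features = {
      "age": constant_op.constant([[25.0], [61.0]], dtype=dtypes.float32),
      "country": sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 0]],
          values=constant_op.constant([4, 7], dtype=dtypes.int64),
          dense_shape=[2, 1]),
  }
  (fc_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _,
   _) = extract_features(features, feature_columns=None,
                         use_core_columns=False)
  # fc_names == ["age", "country"]; dense_floats and the sparse-int lists each
  # have one entry, and the sparse-float lists are empty.
  return fc_names, dense_floats, sparse_float_indices, sparse_int_indices
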
270
271def _dropout_params(mode, ensemble_stats):
272  """Returns parameters relevant for dropout.
273
274  Args:
275    mode: Train/Eval/Infer
276    ensemble_stats: A TreeEnsembleStatsOp result tuple.
277
278  Returns:
279    Whether to apply dropout and a dropout seed.
280  """
281  if mode == learn.ModeKeys.TRAIN:
282    # Do dropout only during training.
283    apply_dropout = True
284    seed = ensemble_stats.attempted_trees
285  else:
286    seed = -1
287    apply_dropout = False
288  return apply_dropout, seed
289
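
# A minimal sketch (illustrative only) of the kind of learner_config the
# GradientBoostedDecisionTreeModel constructor below expects. Only a few
# fields are set; unspecified fields are filled in with defaults by the
# constructor (e.g. growing mode and pruning mode).
def _example_learner_config_sketch():
  config = learner_pb2.LearnerConfig()
  config.num_classes = 2
  config.constraints.max_tree_depth = 4
  config.regularization.l2 = 1.0
  config.learning_rate_tuner.fixed.learning_rate = 0.1
  return config
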
290
291class GradientBoostedDecisionTreeModel(object):
292  """A GBDT model function."""
293
294  def __init__(self,
295               is_chief,
296               num_ps_replicas,
297               ensemble_handle,
298               center_bias,
299               examples_per_layer,
300               learner_config,
301               features,
302               logits_dimension,
303               loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS,
304               feature_columns=None,
305               use_core_columns=False,
306               output_leaf_index=False,
307               output_leaf_index_modes=None,
308               num_quantiles=100):
309    """Construct a new GradientBoostedDecisionTreeModel function.
310
311    Args:
312      is_chief: Whether to build the chief graph.
313      num_ps_replicas: Number of parameter server replicas, can be 0.
314      ensemble_handle: A handle to the ensemble variable.
315      center_bias: Whether to center the bias before growing trees.
316      examples_per_layer: Number of examples to accumulate before growing a tree
317        layer. It can also be a function that computes the number of examples
318        based on the depth of the layer that's being built.
319      learner_config: A learner config.
320      features: `dict` of `Tensor` objects.
321      logits_dimension: An int, the dimension of logits.
322      loss_reduction: Either `SUM_OVER_NONZERO_WEIGHTS` (mean) or `SUM`.
323      feature_columns: A list of feature columns.
324      use_core_columns: A boolean specifying whether core feature columns are
325        used.
326      output_leaf_index: A boolean variable indicating whether to output leaf
327        index into predictions dictionary.
328      output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
329        dictates when leaf indices will be outputted. By default, leaf indices
330        are only outputted in INFER mode.
331      num_quantiles: Number of quantiles to build for numeric feature values.
332
333    Raises:
334      ValueError: if inputs are not valid.
335    """
336    if ensemble_handle is None:
337      raise ValueError("ensemble_handle must be specified.")
338
339    if learner_config is None:
340      raise ValueError("learner_config must be specified.")
341
342    if learner_config.num_classes < 2:
343      raise ValueError("Number of classes must be >=2")
344
345    self._logits_dimension = logits_dimension
346    self._is_chief = is_chief
347    self._num_ps_replicas = num_ps_replicas
348    self._ensemble_handle = ensemble_handle
349    self._center_bias = center_bias
350    self._examples_per_layer = examples_per_layer
351
352    # Check loss reduction value.
353    if (loss_reduction != losses.Reduction.SUM and
354        loss_reduction != losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
355      raise ValueError(
356          "Invalid loss reduction is provided: %s." % loss_reduction)
357    self._loss_reduction = loss_reduction
358
359    # Fill in the defaults.
360    if (learner_config.multi_class_strategy ==
361        learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED):
362      if logits_dimension == 1:
363        learner_config.multi_class_strategy = (
364            learner_pb2.LearnerConfig.TREE_PER_CLASS)
365      else:
366        learner_config.multi_class_strategy = (
367            learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
368
369    if logits_dimension == 1 or learner_config.multi_class_strategy == (
370        learner_pb2.LearnerConfig.TREE_PER_CLASS):
371      self._gradient_shape = tensor_shape.scalar()
372      self._hessian_shape = tensor_shape.scalar()
373    else:
374      if center_bias:
375        raise ValueError("Center bias should be False for multiclass.")
376
377      self._gradient_shape = tensor_shape.TensorShape([logits_dimension])
378      if (learner_config.multi_class_strategy ==
379          learner_pb2.LearnerConfig.FULL_HESSIAN):
380        self._hessian_shape = tensor_shape.TensorShape(
381            ([logits_dimension, logits_dimension]))
382      else:
383        # Diagonal hessian strategy.
384        self._hessian_shape = tensor_shape.TensorShape(([logits_dimension]))
385    if (learner_config.growing_mode ==
386        learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED):
387      learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
388
389    if (learner_config.weak_learner_type == learner_pb2.LearnerConfig
390        .OBLIVIOUS_DECISION_TREE and learner_config.pruning_mode == learner_pb2
391        .LearnerConfig.PRUNING_MODE_UNSPECIFIED):
392      learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE
393
394    if (learner_config.pruning_mode ==
395        learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED):
396      learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE
397
398    if (learner_config.weak_learner_type == learner_pb2.LearnerConfig
399        .OBLIVIOUS_DECISION_TREE and
400        learner_config.pruning_mode == learner_pb2.LearnerConfig.POST_PRUNE):
401      raise ValueError(
          "Post pruning is not implemented for oblivious decision trees.")
403
404    if learner_config.constraints.max_tree_depth == 0:
405      # Use 6 as the default maximum depth.
406      learner_config.constraints.max_tree_depth = 6
407
408    tuner = learner_config.learning_rate_tuner.WhichOneof("tuner")
409    if not tuner:
410      learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
411
412    self._learner_config = learner_config
413    self._feature_columns = feature_columns
414    self._learner_config_serialized = learner_config.SerializeToString()
415    self._num_quantiles = num_quantiles
416    self._max_tree_depth = variables.VariableV1(
417        initial_value=self._learner_config.constraints.max_tree_depth)
418    self._attempted_trees = variables.VariableV1(
419        initial_value=array_ops.zeros([], dtypes.int64),
420        trainable=False,
421        name="attempted_trees")
422    self._finalized_trees = variables.VariableV1(
423        initial_value=array_ops.zeros([], dtypes.int64),
424        trainable=False,
425        name="finalized_trees")
426    if not features:
427      raise ValueError("Features dictionary must be specified.")
428    (fc_names, dense_floats, sparse_float_indices, sparse_float_values,
429     sparse_float_shapes, sparse_int_indices,
430     sparse_int_values, sparse_int_shapes) = extract_features(
431         features, self._feature_columns, use_core_columns)
432    if (learner_config.weak_learner_type == learner_pb2.LearnerConfig
433        .OBLIVIOUS_DECISION_TREE and sparse_float_indices):
      raise ValueError(
          "Oblivious trees don't handle sparse float features yet.")
436
437    logging.info("Active Feature Columns: " + str(fc_names))
438    logging.info("Learner config: " + str(learner_config))
439    self._fc_names = fc_names
440    self._dense_floats = dense_floats
441    self._sparse_float_indices = sparse_float_indices
442    self._sparse_float_values = sparse_float_values
443    self._sparse_float_shapes = sparse_float_shapes
444    self._sparse_int_indices = sparse_int_indices
445    self._sparse_int_values = sparse_int_values
446    self._sparse_int_shapes = sparse_int_shapes
447    self._reduce_dim = (
448        self._learner_config.multi_class_strategy ==
449        learner_pb2.LearnerConfig.TREE_PER_CLASS and
450        learner_config.num_classes == 2)
451
452    if output_leaf_index_modes is None:
453      output_leaf_index_modes = [learn.ModeKeys.INFER]
454    elif not all(
455        mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL,
456                 learn.ModeKeys.INFER) for mode in output_leaf_index_modes):
457      raise ValueError("output_leaf_index_modes should only contain ModeKeys.")
458
459    self._output_leaf_index = output_leaf_index
460    self._output_leaf_index_modes = output_leaf_index_modes
461
462  def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
463    """Runs prediction and returns a dictionary of the prediction results.
464
465    Args:
466      ensemble_handle: ensemble resource handle.
467      ensemble_stamp: stamp of ensemble resource.
468      mode: learn.ModeKeys.TRAIN or EVAL or INFER.
469
470    Returns:
      a dictionary of prediction results -
        ENSEMBLE_STAMP, PREDICTIONS, PARTITION_IDS,
        NUM_LAYERS_ATTEMPTED, NUM_TREES_ATTEMPTED.
474    """
475    ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle,
476                                                      ensemble_stamp)
477    num_handlers = (
478        len(self._dense_floats) + len(self._sparse_float_shapes) + len(
479            self._sparse_int_shapes))
480    # Used during feature selection.
481    used_handlers = model_ops.tree_ensemble_used_handlers(
482        ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers)
483
484    # We don't need dropout info - we can always restore it based on the
485    # seed.
486    apply_dropout, seed = _dropout_params(mode, ensemble_stats)
487    # Make sure ensemble stats run. This will check that the ensemble has
488    # the right stamp.
489    with ops.control_dependencies(ensemble_stats):
490      leaf_index = None
491      if self._output_leaf_index and mode in self._output_leaf_index_modes:
492        predictions, _, leaf_index = (
493            prediction_ops).gradient_trees_prediction_verbose(
494                ensemble_handle,
495                seed,
496                self._dense_floats,
497                self._sparse_float_indices,
498                self._sparse_float_values,
499                self._sparse_float_shapes,
500                self._sparse_int_indices,
501                self._sparse_int_values,
502                self._sparse_int_shapes,
503                learner_config=self._learner_config_serialized,
504                apply_dropout=apply_dropout,
505                apply_averaging=mode != learn.ModeKeys.TRAIN,
506                use_locking=True,
507                center_bias=self._center_bias,
508                reduce_dim=self._reduce_dim)
509      else:
510        leaf_index = None
511        predictions, _ = prediction_ops.gradient_trees_prediction(
512            ensemble_handle,
513            seed,
514            self._dense_floats,
515            self._sparse_float_indices,
516            self._sparse_float_values,
517            self._sparse_float_shapes,
518            self._sparse_int_indices,
519            self._sparse_int_values,
520            self._sparse_int_shapes,
521            learner_config=self._learner_config_serialized,
522            apply_dropout=apply_dropout,
523            apply_averaging=mode != learn.ModeKeys.TRAIN,
524            use_locking=True,
525            center_bias=self._center_bias,
526            reduce_dim=self._reduce_dim)
527      partition_ids = prediction_ops.gradient_trees_partition_examples(
528          ensemble_handle,
529          self._dense_floats,
530          self._sparse_float_indices,
531          self._sparse_float_values,
532          self._sparse_float_shapes,
533          self._sparse_int_indices,
534          self._sparse_int_values,
535          self._sparse_int_shapes,
536          use_locking=True)
537
538    return _make_predictions_dict(ensemble_stamp, predictions, partition_ids,
539                                  ensemble_stats, used_handlers, leaf_index)
540
541  def predict(self, mode):
542    """Returns predictions given the features and mode.
543
544    Args:
545      mode: Mode the graph is running in (train|predict|eval).
546
547    Returns:
548      A dict of predictions tensors.
549
550    Raises:
551      ValueError: if features is not valid.
552    """
553
554    # Use the current ensemble to predict on the current batch of input.
555    # For faster prediction we check if the inputs are on the same device
556    # as the model. If not, we create a copy of the model on the worker.
557    input_deps = (
558        self._dense_floats + self._sparse_float_indices +
559        self._sparse_int_indices)
560    if not input_deps:
561      raise ValueError("No input tensors for prediction.")
562
563    # Get most current model stamp.
564    ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)
565
566    # Determine if ensemble is colocated with the inputs.
567    if self._ensemble_handle.device != input_deps[0].device:
568      # Create a local ensemble and get its local stamp.
569      with ops.name_scope("local_ensemble", "TreeEnsembleVariable"):
570        local_ensemble_handle = (
571            gen_model_ops.decision_tree_ensemble_resource_handle_op(
572                self._ensemble_handle.op.name + "/local_ensemble"))
573        create_op = gen_model_ops.create_tree_ensemble_variable(
574            local_ensemble_handle, stamp_token=-1, tree_ensemble_config="")
575        with ops.control_dependencies([create_op]):
576          local_stamp = model_ops.tree_ensemble_stamp_token(
577              local_ensemble_handle)
578
579      # Determine whether the local ensemble is stale and update it if needed.
580      def _refresh_local_ensemble_fn():
581        # Serialize the model from parameter server after reading the inputs.
582        with ops.control_dependencies([input_deps[0]]):
583          (ensemble_stamp, serialized_model) = (
584              model_ops.tree_ensemble_serialize(self._ensemble_handle))
585
586        # Update local ensemble with the serialized model from parameter server.
587        with ops.control_dependencies([create_op]):
588          return model_ops.tree_ensemble_deserialize(
589              local_ensemble_handle,
590              stamp_token=ensemble_stamp,
591              tree_ensemble_config=serialized_model), ensemble_stamp
592
593      refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
594          math_ops.not_equal(ensemble_stamp,
595                             local_stamp), _refresh_local_ensemble_fn,
596          lambda: (control_flow_ops.no_op(), ensemble_stamp))
597
598      # Once updated, use the local model for prediction.
599      with ops.control_dependencies([refresh_local_ensemble]):
600        return self._predict_and_return_dict(local_ensemble_handle,
601                                             ensemble_stamp, mode)
602    else:
603      # Use ensemble_handle directly, if colocated.
604      with ops.device(self._ensemble_handle.device):
605        return self._predict_and_return_dict(self._ensemble_handle,
606                                             ensemble_stamp, mode)
607
608  def _get_class_id(self, predictions_dict):
609    # Handle different multiclass strategies.
610    if (self._learner_config.multi_class_strategy ==
611        learner_pb2.LearnerConfig.TREE_PER_CLASS and
612        self._logits_dimension != 1):
613      # Choose the class for which the tree is built (one vs rest).
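      # With TREE_PER_CLASS, the t-th attempted tree targets class
      # (t % logits_dimension), so the target class cycles as trees are added.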
614      return math_ops.cast(
615          predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension,
616          dtypes.int32)
617    return constant_op.constant(-1, dtype=dtypes.int32)
618
619  def update_stats(self, loss, predictions_dict, gradients=None, hessians=None):
620    """Update the accumulators with stats from this batch.
621
    Args:
      loss: A scalar tensor representing average loss of examples.
      predictions_dict: Dictionary of Rank 2 `Tensor` representing information
          about predictions per example.
      gradients: A tensor with the gradients with respect to logits from
        predictions_dict. If not provided, TensorFlow will compute them via
        automatic differentiation.
      hessians: A tensor with the hessians with respect to logits from
        predictions_dict. If not provided, TensorFlow will compute them via
        automatic differentiation.

    Returns:
      Three values:
        - A list of ops that update the stats accumulators for this batch.
        - An op that increments the stamp but removes all the trees and resets
            the handlers. This can be used to reset the state of the ensemble.
        - A `GBDTTrainingState` namedtuple containing the training state.

    Raises:
      ValueError: if inputs are not valid.
    """
643    # Get the worker device from input dependencies.
644    input_deps = (
645        self._dense_floats + self._sparse_float_indices +
646        self._sparse_int_indices)
647    worker_device = input_deps[0].device
648
649    # Get tensors relevant for training and form the loss.
650    predictions = predictions_dict[PREDICTIONS]
651    partition_ids = predictions_dict[PARTITION_IDS]
652    ensemble_stamp = predictions_dict[ENSEMBLE_STAMP]
653    if gradients is None:
654      gradients = gradients_impl.gradients(
655          loss,
656          predictions,
657          name="Gradients",
658          colocate_gradients_with_ops=False,
659          gate_gradients=0,
660          aggregation_method=None)[0]
661    strategy = self._learner_config.multi_class_strategy
662
663    class_id = self._get_class_id(predictions_dict)
664    # Handle different multiclass strategies.
665    if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
666      # We build one vs rest trees.
667      if self._logits_dimension == 1:
668        # We have only 1 score, gradients is of shape [batch, 1].
669        if hessians is None:
670          hessians = gradients_impl.gradients(
671              gradients,
672              predictions,
673              name="Hessian",
674              colocate_gradients_with_ops=False,
675              gate_gradients=0,
676              aggregation_method=None)[0]
677
678        squeezed_gradients = array_ops.squeeze(gradients, axis=[1])
679        squeezed_hessians = array_ops.squeeze(hessians, axis=[1])
680      else:
681        if hessians is not None:
682          raise ValueError("Providing hessians is not yet supported here.")
683        hessian_list = self._diagonal_hessian(gradients, predictions)
684        # Assemble hessian list into a tensor.
685        hessians = array_ops.stack(hessian_list, axis=1)
686        # Use class id tensor to get the column with that index from gradients
687        # and hessians.
688        squeezed_gradients = array_ops.squeeze(
689            _get_column_by_index(gradients, class_id))
690        squeezed_hessians = array_ops.squeeze(
691            _get_column_by_index(hessians, class_id))
692    else:
693      if hessians is not None:
694        raise ValueError("Providing hessians is not yet supported here.")
695      # Other multiclass strategies.
696      if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN:
697        hessian_list = self._full_hessian(gradients, predictions)
698      else:
699        # Diagonal hessian strategy.
700        hessian_list = self._diagonal_hessian(gradients, predictions)
701
702      squeezed_gradients = gradients
703      hessians = array_ops.stack(hessian_list, axis=1)
704      squeezed_hessians = hessians
705
    # Get the per-example weights used for quantile calculation.
707    weights = self._get_weights(self._hessian_shape, squeezed_hessians)
708
709    # Create all handlers ensuring resources are evenly allocated across PS.
710    fc_name_idx = 0
711    handlers = []
712    init_stamp_token = constant_op.constant(0, dtype=dtypes.int64)
713    l1_regularization = constant_op.constant(
714        self._learner_config.regularization.l1, dtypes.float32)
715    l2_regularization = constant_op.constant(
716        self._learner_config.regularization.l2, dtypes.float32)
717    tree_complexity_regularization = constant_op.constant(
718        self._learner_config.regularization.tree_complexity, dtypes.float32)
719    min_node_weight = constant_op.constant(
720        self._learner_config.constraints.min_node_weight, dtypes.float32)
721    loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM
722    loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
723    weak_learner_type = constant_op.constant(
724        self._learner_config.weak_learner_type)
725    num_quantiles = self._num_quantiles
726    epsilon = 1.0 / num_quantiles
727    strategy_tensor = constant_op.constant(strategy)
728    with ops.device(self._get_replica_device_setter(worker_device)):
729      # Create handlers for dense float columns
730      for dense_float_column_idx in range(len(self._dense_floats)):
731        fc_name = self._fc_names[fc_name_idx]
732        handlers.append(
733            ordinal_split_handler.DenseSplitHandler(
734                l1_regularization=l1_regularization,
735                l2_regularization=l2_regularization,
736                tree_complexity_regularization=tree_complexity_regularization,
737                min_node_weight=min_node_weight,
738                feature_column_group_id=constant_op.constant(
739                    dense_float_column_idx),
740                epsilon=epsilon,
741                num_quantiles=num_quantiles,
742                dense_float_column=self._dense_floats[dense_float_column_idx],
743                name=fc_name,
744                gradient_shape=self._gradient_shape,
745                hessian_shape=self._hessian_shape,
746                multiclass_strategy=strategy_tensor,
747                init_stamp_token=init_stamp_token,
748                loss_uses_sum_reduction=loss_uses_sum_reduction,
749                weak_learner_type=weak_learner_type,
750            ))
751        fc_name_idx += 1
752
753      # Create handlers for sparse float columns.
754      for sparse_float_column_idx in range(len(self._sparse_float_indices)):
755        fc_name = self._fc_names[fc_name_idx]
756        handlers.append(
757            ordinal_split_handler.SparseSplitHandler(
758                l1_regularization=l1_regularization,
759                l2_regularization=l2_regularization,
760                tree_complexity_regularization=tree_complexity_regularization,
761                min_node_weight=min_node_weight,
762                feature_column_group_id=constant_op.constant(
763                    sparse_float_column_idx),
764                epsilon=epsilon,
765                num_quantiles=num_quantiles,
766                sparse_float_column=sparse_tensor.SparseTensor(
767                    self._sparse_float_indices[sparse_float_column_idx],
768                    self._sparse_float_values[sparse_float_column_idx],
769                    self._sparse_float_shapes[sparse_float_column_idx]),
770                name=fc_name,
771                gradient_shape=self._gradient_shape,
772                hessian_shape=self._hessian_shape,
773                multiclass_strategy=strategy_tensor,
774                init_stamp_token=init_stamp_token,
775                loss_uses_sum_reduction=loss_uses_sum_reduction))
776        fc_name_idx += 1
777
778      # Create handlers for sparse int columns.
779      for sparse_int_column_idx in range(len(self._sparse_int_indices)):
780        fc_name = self._fc_names[fc_name_idx]
781        handlers.append(
782            categorical_split_handler.EqualitySplitHandler(
783                l1_regularization=l1_regularization,
784                l2_regularization=l2_regularization,
785                tree_complexity_regularization=tree_complexity_regularization,
786                min_node_weight=min_node_weight,
787                feature_column_group_id=constant_op.constant(
788                    sparse_int_column_idx),
789                sparse_int_column=sparse_tensor.SparseTensor(
790                    self._sparse_int_indices[sparse_int_column_idx],
791                    self._sparse_int_values[sparse_int_column_idx],
792                    self._sparse_int_shapes[sparse_int_column_idx]),
793                name=fc_name,
794                gradient_shape=self._gradient_shape,
795                hessian_shape=self._hessian_shape,
796                multiclass_strategy=strategy_tensor,
797                init_stamp_token=init_stamp_token,
798                loss_uses_sum_reduction=loss_uses_sum_reduction,
799                weak_learner_type=weak_learner_type))
800        fc_name_idx += 1
801
802      # Create ensemble stats variables.
803      num_layer_examples = variables.VariableV1(
804          initial_value=array_ops.zeros([], dtypes.int64),
805          name="num_layer_examples",
806          trainable=False)
807      num_layer_steps = variables.VariableV1(
808          initial_value=array_ops.zeros([], dtypes.int64),
809          name="num_layer_steps",
810          trainable=False)
811      num_layers = variables.VariableV1(
812          initial_value=array_ops.zeros([], dtypes.int64),
813          name="num_layers",
814          trainable=False)
815      active_tree = variables.VariableV1(
816          initial_value=array_ops.zeros([], dtypes.int64),
817          name="active_tree",
818          trainable=False)
819      active_layer = variables.VariableV1(
820          initial_value=array_ops.zeros([], dtypes.int64),
821          name="active_layer",
822          trainable=False)
823      # Variable that becomes false once bias centering is done.
824      continue_centering = variables.VariableV1(
825          initial_value=self._center_bias,
826          name="continue_centering",
827          trainable=False)
828      # Create bias stats accumulator.
829      bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator(
830          stamp_token=0,
831          gradient_shape=self._gradient_shape,
832          hessian_shape=self._hessian_shape,
833          name="BiasAccumulator")
834      # Create steps accumulator.
835      steps_accumulator = stats_accumulator_ops.StatsAccumulator(
836          stamp_token=0,
837          gradient_shape=tensor_shape.scalar(),
838          hessian_shape=tensor_shape.scalar(),
839          name="StepsAccumulator")
840    # Create ensemble stats summaries.
841    summary.scalar("layer_stats/num_examples", num_layer_examples)
842    summary.scalar("layer_stats/num_steps", num_layer_steps)
843    summary.scalar("ensemble_stats/active_tree", active_tree)
844    summary.scalar("ensemble_stats/active_layer", active_layer)
845
846    # Update bias stats.
847    stats_update_ops = []
848
849    stats_update_ops.append(
850        control_flow_ops.cond(
851            continue_centering,
852            self._make_update_bias_stats_fn(ensemble_stamp, predictions,
853                                            gradients, bias_stats_accumulator,
854                                            hessians), control_flow_ops.no_op))
855
856    # Update handler stats.
857    handler_reads = collections.OrderedDict()
858    for handler in handlers:
859      handler_reads[handler] = handler.scheduled_reads()
860
861    handler_results = batch_ops_utils.run_handler_scheduled_ops(
862        handler_reads, ensemble_stamp, worker_device)
863    per_handler_updates = collections.OrderedDict()
    # Two values per handler. The first indicates whether the handler is
    # active for the current layer; the second indicates whether it will be
    # active for the next layer.
867    subsampling_type = self._learner_config.WhichOneof("feature_fraction")
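    # The stateless draws below are a deterministic function of their seed
    # (the layer or tree counter), so repeated evaluations with the same
    # counter produce the same handler mask. For example, with
    # feature_fraction_per_level = 0.5 and four handlers, on average two
    # handlers pass the `< 0.5` test on each layer.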
868    if subsampling_type == "feature_fraction_per_level":
869      seed = predictions_dict[NUM_LAYERS_ATTEMPTED]
870      active_handlers_current_layer = stateless.stateless_random_uniform(
871          shape=[len(handlers)], seed=[seed, 1])
872      active_handlers_next_layer = stateless.stateless_random_uniform(
873          shape=[len(handlers)], seed=[seed + 1, 1])
874      active_handlers = array_ops.stack(
875          [active_handlers_current_layer, active_handlers_next_layer], axis=1)
876      active_handlers = (
877          active_handlers < self._learner_config.feature_fraction_per_level)
878    elif subsampling_type == "feature_fraction_per_tree":
879      seed = predictions_dict[NUM_TREES_ATTEMPTED]
880      active_handlers_current_layer = stateless.stateless_random_uniform(
881          shape=[len(handlers)], seed=[seed, 2])
882      active_handlers_current_layer = (
883          active_handlers_current_layer <
884          self._learner_config.feature_fraction_per_tree)
885      active_handlers = array_ops.stack(
886          [
887              active_handlers_current_layer,
888              array_ops.ones([len(handlers)], dtype=dtypes.bool)
889          ],
890          axis=1)
891    else:
892      active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool)
893
894    if self._learner_config.constraints.max_number_of_unique_feature_columns:
895      target = (
896          self._learner_config.constraints.max_number_of_unique_feature_columns)
897
898      def _feature_selection_active_handlers():
899        # The active list for current and the next iteration.
900        used_handlers = array_ops.reshape(predictions_dict[USED_HANDLERS_MASK],
901                                          [-1, 1])
902        used_handlers = array_ops.concat([used_handlers, used_handlers], axis=1)
903        return math_ops.logical_and(used_handlers, active_handlers)
904
905      active_handlers = (
906          control_flow_ops.cond(predictions_dict[NUM_USED_HANDLERS] >= target,
907                                _feature_selection_active_handlers,
908                                lambda: active_handlers))
909
910    # Prepare empty gradients and hessians when handlers are not ready.
911    empty_hess_shape = [1] + self._hessian_shape.as_list()
912    empty_grad_shape = [1] + self._gradient_shape.as_list()
913
914    empty_gradients = constant_op.constant_v1(
915        [], dtype=dtypes.float32, shape=empty_grad_shape)
916    empty_hessians = constant_op.constant_v1(
917        [], dtype=dtypes.float32, shape=empty_hess_shape)
918
919    active_handlers = array_ops.unstack(active_handlers, axis=0)
920    for handler_idx in range(len(handlers)):
921      handler = handlers[handler_idx]
922      is_active = active_handlers[handler_idx]
923      updates, scheduled_updates = handler.update_stats(
924          ensemble_stamp, partition_ids, squeezed_gradients, squeezed_hessians,
925          empty_gradients, empty_hessians, weights, is_active,
926          handler_results[handler])
927      stats_update_ops.append(updates)
928      per_handler_updates[handler] = scheduled_updates
929
930    update_results = batch_ops_utils.run_handler_scheduled_ops(
931        per_handler_updates, ensemble_stamp, worker_device)
932    for update in update_results.values():
933      stats_update_ops += update
934
935    training_state = GBDTTrainingState(
936        num_layer_examples=num_layer_examples,
937        num_layer_steps=num_layer_steps,
938        num_layers=num_layers,
939        active_tree=active_tree,
940        active_layer=active_layer,
941        continue_centering=continue_centering,
942        bias_stats_accumulator=bias_stats_accumulator,
943        steps_accumulator=steps_accumulator,
944        handlers=handlers)
945
946    reset_op = control_flow_ops.no_op()
947    if self._is_chief:
948      # Advance the ensemble stamp to throw away staggered workers.
949      stamp_token, _ = model_ops.tree_ensemble_serialize(self._ensemble_handle)
950      next_stamp_token = stamp_token + 1
951
952      reset_ops = []
953      for handler in handlers:
954        reset_ops.append(handler.reset(stamp_token, next_stamp_token))
955      if self._center_bias:
956        reset_ops.append(
957            bias_stats_accumulator.flush(stamp_token, next_stamp_token))
958      reset_ops.append(steps_accumulator.flush(stamp_token, next_stamp_token))
959      reset_ops.append(self._finalized_trees.assign(0).op)
960      reset_ops.append(self._attempted_trees.assign(0).op)
961      reset_ops.append(
962          model_ops.tree_ensemble_deserialize(
963              self._ensemble_handle,
964              stamp_token=next_stamp_token,
965              tree_ensemble_config="",
966              name="reset_gbdt"))
967
968      reset_op = control_flow_ops.group([reset_ops])
969
970    return stats_update_ops, reset_op, training_state
971
972  def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict,
973                                                       training_state):
974    """Increments number of visited examples and grows the ensemble.
975
    If the number of visited examples reaches the target examples_per_layer,
    the ensemble is updated.

    Args:
      predictions_dict: Dictionary of Rank 2 `Tensor` representing information
          about predictions per example.
      training_state: `GBDTTrainingState` namedtuple returned by update_stats.

    Returns:
      An op that updates the counters and potentially grows the ensemble.
986    """
987    batch_size = math_ops.cast(
988        array_ops.shape(predictions_dict[PREDICTIONS])[0], dtypes.float32)
989    ensemble_stamp = predictions_dict[ENSEMBLE_STAMP]
990    # Accumulate a step after updating stats.
991
992    steps_accumulator = training_state.steps_accumulator
993    num_layer_examples = training_state.num_layer_examples
994    num_layer_steps = training_state.num_layer_steps
995    active_layer = training_state.active_layer
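    # The steps accumulator uses a single slot (partition 0, feature id
    # [0, 0]); its "gradient" accumulates the number of examples seen and its
    # "hessian" accumulates the number of steps, which are read back below as
    # acc_examples and acc_steps.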
996    add_step_op = steps_accumulator.add(
997        ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0])
998
999    # After adding the step, decide if further processing is needed.
1000    ensemble_update_ops = [add_step_op]
1001    class_id = self._get_class_id(predictions_dict)
1002
1003    with ops.control_dependencies([add_step_op]):
1004      if self._is_chief:
1005        dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED]
1006
1007        # Get accumulated steps and examples for the current layer.
1008        _, _, _, _, acc_examples, acc_steps = (
1009            steps_accumulator.saveable.serialize())
1010        acc_examples = math_ops.cast(acc_examples[0], dtypes.int64)
1011        acc_steps = math_ops.cast(acc_steps[0], dtypes.int64)
1012        ensemble_update_ops.append(
1013            num_layer_examples.assign(acc_examples))
1014        ensemble_update_ops.append(num_layer_steps.assign(acc_steps))
1015        # Determine whether we need to update tree ensemble.
1016        examples_per_layer = self._examples_per_layer
1017        if callable(examples_per_layer):
1018          examples_per_layer = examples_per_layer(active_layer)
1019        ensemble_update_ops.append(
1020            control_flow_ops.cond(
1021                acc_examples >= examples_per_layer,
1022                self.make_update_ensemble_fn(ensemble_stamp, training_state,
1023                                             dropout_seed, class_id),
1024                control_flow_ops.no_op))
1025
    # Note: the loss is calculated from predictions that include dropout, so
    # its value may fluctuate noticeably across steps when the dropout ratio
    # is high. eval_loss may be a better signal of convergence.
1029    return control_flow_ops.group(*ensemble_update_ops)
1030
1031  def make_update_ensemble_fn(self, ensemble_stamp, training_state,
1032                              dropout_seed, class_id):
1033    """A method to create the function which updates the tree ensemble."""
1034    # Determine learning rate.
1035    learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof(
1036        "tuner")
1037    if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout":
1038      tuner = getattr(self._learner_config.learning_rate_tuner,
1039                      learning_rate_tuner)
1040      learning_rate = tuner.learning_rate
1041    else:
1042      # TODO(nponomareva, soroush) do the line search.
1043      raise ValueError("Line search learning rate is not yet supported.")
1044
1045    def _update_ensemble():
1046      """A method to update the tree ensemble."""
1047      # Get next stamp token.
1048      next_ensemble_stamp = ensemble_stamp + 1
1049      # Finalize bias stats.
1050      _, _, _, bias_grads, bias_hess = (
1051          training_state.bias_stats_accumulator.flush(ensemble_stamp,
1052                                                      next_ensemble_stamp))
1053
1054      # Finalize handler splits.
1055      are_splits_ready_list = []
1056      partition_ids_list = []
1057      gains_list = []
1058      split_info_list = []
1059
1060      for handler in training_state.handlers:
1061        (are_splits_ready,
1062         partition_ids, gains, split_info) = handler.make_splits(
1063             ensemble_stamp, next_ensemble_stamp, class_id)
1064        are_splits_ready_list.append(are_splits_ready)
1065        partition_ids_list.append(partition_ids)
1066        gains_list.append(gains)
1067        split_info_list.append(split_info)
1068      # Stack all the inputs to one tensor per type.
1069      # This is a workaround for the slowness of graph building in tf.cond.
1070      # See (b/36554864).
1071      split_sizes = array_ops.reshape(
1072          array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
1073      partition_ids = array_ops.concat(partition_ids_list, axis=0)
1074      gains = array_ops.concat(gains_list, axis=0)
1075      split_infos = array_ops.concat(split_info_list, axis=0)
1076
1077      # Determine if all splits are ready.
1078      are_all_splits_ready = math_ops.reduce_all(
1079          array_ops.stack(
1080              are_splits_ready_list, axis=0, name="stack_handler_readiness"))
1081
1082      # Define bias centering update operation.
1083      def _center_bias_fn():
1084        # Center tree ensemble bias.
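        # This is one Newton step on the bias: delta = -grad / hess wherever
        # the hessian is positive, and zero otherwise.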
1085        delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess,
1086                                        array_ops.zeros_like(bias_grads))
1087        center_bias = training_ops.center_tree_ensemble_bias(
1088            tree_ensemble_handle=self._ensemble_handle,
1089            stamp_token=ensemble_stamp,
1090            next_stamp_token=next_ensemble_stamp,
1091            delta_updates=delta_updates,
1092            learner_config=self._learner_config_serialized)
1093        return training_state.continue_centering.assign(center_bias)
1094
1095      # Define ensemble growing operations.
1096      def _grow_ensemble_ready_fn():
1097        # Grow the ensemble given the current candidates.
1098        sizes = array_ops.unstack(split_sizes)
1099        partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
        # When using the oblivious decision tree as the weak learner, each
        # handler produces one gain and one split, rather than one per
        # partition.
1102        if self._learner_config.weak_learner_type == (
1103            learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE):
1104          sizes = len(training_state.handlers)
1105
1106        gains_list = list(array_ops.split(gains, sizes, axis=0))
1107        split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
1108        return training_ops.grow_tree_ensemble(
1109            tree_ensemble_handle=self._ensemble_handle,
1110            stamp_token=ensemble_stamp,
1111            next_stamp_token=next_ensemble_stamp,
1112            learning_rate=learning_rate,
1113            partition_ids=partition_ids_list,
1114            gains=gains_list,
1115            splits=split_info_list,
1116            learner_config=self._learner_config_serialized,
1117            dropout_seed=dropout_seed,
1118            center_bias=self._center_bias,
1119            max_tree_depth=self._max_tree_depth,
1120            weak_learner_type=self._learner_config.weak_learner_type)
1121
1122      def _grow_ensemble_not_ready_fn():
1123        # Don't grow the ensemble, just update the stamp.
1124        return training_ops.grow_tree_ensemble(
1125            tree_ensemble_handle=self._ensemble_handle,
1126            stamp_token=ensemble_stamp,
1127            next_stamp_token=next_ensemble_stamp,
1128            learning_rate=0,
1129            partition_ids=[],
1130            gains=[],
1131            splits=[],
1132            learner_config=self._learner_config_serialized,
1133            dropout_seed=dropout_seed,
1134            center_bias=self._center_bias,
1135            max_tree_depth=self._max_tree_depth,
1136            weak_learner_type=self._learner_config.weak_learner_type)
1137
1138      def _grow_ensemble_fn():
1139        # Conditionally grow an ensemble depending on whether the splits
1140        # from all the handlers are ready.
1141        return control_flow_ops.cond(are_all_splits_ready,
1142                                     _grow_ensemble_ready_fn,
1143                                     _grow_ensemble_not_ready_fn)
1144
1145      # Update ensemble.
1146      update_ops = [are_all_splits_ready]
1147      if self._center_bias:
1148        update_model = control_flow_ops.cond(training_state.continue_centering,
1149                                             _center_bias_fn, _grow_ensemble_fn)
1150      else:
1151        update_model = _grow_ensemble_fn()
1152      update_ops.append(update_model)
1153
1154      # Update ensemble stats.
1155      with ops.control_dependencies([update_model]):
1156        stats = training_ops.tree_ensemble_stats(
1157            self._ensemble_handle, stamp_token=next_ensemble_stamp)
1158        update_ops.append(self._finalized_trees.assign(stats.num_trees))
1159        update_ops.append(self._attempted_trees.assign(stats.attempted_trees))
1160        update_ops.append(training_state.num_layers.assign(stats.num_layers))
1161        update_ops.append(training_state.active_tree.assign(stats.active_tree))
1162        update_ops.append(
1163            training_state.active_layer.assign(stats.active_layer))
1164
1165      # Flush step stats.
1166      update_ops.extend(
1167          training_state.steps_accumulator.flush(ensemble_stamp,
1168                                                 next_ensemble_stamp))
1169      return control_flow_ops.group(*update_ops, name="update_ensemble")
1170
1171    return _update_ensemble
1172
1173  def get_number_of_trees_tensor(self):
1174    return self._finalized_trees, self._attempted_trees
1175
1176  def get_max_tree_depth(self):
1177    return self._max_tree_depth
1178
1179  def train(self, loss, predictions_dict, labels, gradients=None,
1180            hessians=None):
    """Updates the accumulator stats and grows the ensemble.

    Args:
      loss: A scalar tensor representing average loss of examples.
      predictions_dict: Dictionary of Rank 2 `Tensor` representing information
          about predictions per example.
      labels: Rank 2 `Tensor` representing labels per example. Has no effect
          on the training and is only kept for backward compatibility.
      gradients: A tensor with the gradients with respect to logits from
        predictions_dict. If not provided, TensorFlow will compute them via
        automatic differentiation.
      hessians: A tensor with the hessians with respect to logits from
        predictions_dict. If not provided, TensorFlow will compute them via
        automatic differentiation.
1195
1196    Returns:
1197      An op that adds a new tree to the ensemble.
1198
1199    Raises:
1200      ValueError: if inputs are not valid.
1201    """
1202    del labels  # unused; kept for backward compatibility.
1203    update_op, _, training_state = self.update_stats(loss, predictions_dict,
1204                                                     gradients, hessians)
1205    with ops.control_dependencies(update_op):
1206      return self.increment_step_counter_and_maybe_update_ensemble(
1207          predictions_dict, training_state)
1208
1209  def _get_weights(self, hessian_shape, hessians):
1210    """Derives weights to be used based on hessians and multiclass strategy."""
1211    if hessian_shape == tensor_shape.scalar():
1212      # This is tree per class.
1213      weights = hessians
1214    elif len(hessian_shape.dims) == 1:
1215      # This is diagonal hessian.
1216      weights = math_ops.reduce_sum(hessians, axis=1)
1217    else:
1218      # This is full hessian.
1219      weights = math_ops.trace(hessians)
1220    return weights
1221
1222  def _full_hessian(self, grads, predictions):
1223    """Prepares hessians for full-hessian multiclass strategy."""
1224    # Because of
1225    # https://github.com/tensorflow/tensorflow/issues/675, we can't just
1226    # compute the full hessian with a single call to gradients, but instead
1227    # must compute it row-by-row.
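    # Concretely, row i of the hessian is d(grads[:, i]) / d(predictions),
    # i.e. the second derivatives d^2 loss / (d logit_i d logit_j) for
    # j = 1..K, computed one row at a time below.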
1228    gradients_list = array_ops.unstack(
1229        grads, num=self._logits_dimension, axis=1)
1230    hessian_rows = []
1231
1232    for row in range(self._logits_dimension):
      # If the current row is i and K is the number of classes, each row
      # returns a tensor of size batch_size x K representing, for each
      # example, dx_i dx_1, dx_i dx_2, ..., dx_i dx_K.
1236      hessian_row = gradients_impl.gradients(
1237          gradients_list[row],
1238          predictions,
1239          name="Hessian_%d" % row,
1240          colocate_gradients_with_ops=False,
1241          gate_gradients=0,
1242          aggregation_method=None)
1243
1244      # hessian_row is of dimension 1, batch_size, K, => trim first dimension
1245      # to get batch_size x K
1246      hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0])
1247      hessian_rows.append(hessian_row)
1248    return hessian_rows
1249
1250  def _diagonal_hessian(self, grads, predictions):
1251    """Prepares hessians for diagonal-hessian multiclass mode."""
1252    diag_hessian_list = []
1253
1254    gradients_list = array_ops.unstack(
1255        grads, num=self._logits_dimension, axis=1)
1256
1257    for row, row_grads in enumerate(gradients_list):
      # If the current row is i and K is the number of classes, each row
      # returns a tensor of size batch_size x K representing, for each
      # example, dx_i dx_1, dx_i dx_2, ..., dx_i dx_K.
1261      hessian_row = gradients_impl.gradients(
1262          row_grads,
1263          predictions,
1264          name="Hessian_%d" % row,
1265          colocate_gradients_with_ops=False,
1266          gate_gradients=0,
1267          aggregation_method=None)
1268
1269      # hessian_row is of dimension 1, batch_size, K, => trim first dimension
1270      # to get batch_size x K
1271      hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0])
1272
1273      # Get dx_i^2 for the whole batch.
1274      elem = array_ops.transpose(hessian_row)[row]
1275      diag_hessian_list.append(elem)
1276
1277    return diag_hessian_list
1278
1279  def _get_replica_device_setter(self, worker_device):
1280    """Creates a replica device setter."""
1281    ps_tasks = self._num_ps_replicas
1282    ps_ops = list(device_setter.STANDARD_PS_OPS)
1283    ps_ops.extend([
1284        "DecisionTreeEnsembleResourceHandleOp",
1285        "StatsAccumulatorScalarResourceHandleOp",
1286        "StatsAccumulatorTensorResourceHandleOp",
1287    ])
1288    ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks)
1289    return device_setter.replica_device_setter(
1290        worker_device=worker_device,
1291        ps_tasks=ps_tasks,
1292        merge_devices=True,
1293        ps_ops=ps_ops,
1294        ps_strategy=ps_strategy)
1295
1296  def _make_update_bias_stats_fn(self,
1297                                 ensemble_stamp,
1298                                 predictions,
1299                                 gradients,
1300                                 bias_stats_accumulator,
1301                                 hessians=None):
1302    """A method to create the function which updates the bias stats."""
1303
1304    def _update_bias_stats():
1305      """A method to update the bias stats."""
1306      # Get reduced gradients and hessians.
1307      grads_sum = math_ops.reduce_sum(gradients, 0)
1308      if hessians is not None:
1309        hess = hessians
1310      else:
1311        hess = gradients_impl.gradients(
1312            grads_sum,
1313            predictions,
1314            name="Hessians",
1315            colocate_gradients_with_ops=False,
1316            gate_gradients=0,
1317            aggregation_method=None)[0]
1318      hess_sum = math_ops.reduce_sum(hess, 0)
1319
1320      # Accumulate gradients and hessians.
1321      partition_ids = math_ops.range(self._logits_dimension)
1322      feature_ids = array_ops.zeros(
1323          [self._logits_dimension, 2], dtype=dtypes.int64)
1324
1325      add_stats_op = bias_stats_accumulator.add(
1326          ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum)
1327      return control_flow_ops.group(*[add_stats_op], name="update_bias_stats")
1328
1329    return _update_bias_stats
1330
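
# A typical (illustrative) call sequence from an estimator model_fn, assuming
# `gbdt_model` is a GradientBoostedDecisionTreeModel constructed as above and
# that the caller computes `loss` from the logits and labels:
#
#   predictions_dict = gbdt_model.predict(mode)
#   logits = predictions_dict[PREDICTIONS]
#   loss = ...  # provided by the caller's head / loss function
#   train_op = gbdt_model.train(loss, predictions_dict, labels)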