• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""GTFlow Estimator definition."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensorflow.contrib.boosted_trees.estimator_batch import model
24from tensorflow.contrib.boosted_trees.python.utils import losses
25from tensorflow.contrib.learn.python.learn.estimators import estimator
26from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
27from tensorflow.python.estimator.canned import head as core_head_lib
28from tensorflow.python.estimator import estimator as core_estimator
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.losses import losses as core_losses
31from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head
32from tensorflow.python.ops import array_ops
33
34# ================== Old estimator interface===================================
35# The estimators below were designed for old feature columns and old estimator
36# interface. They can be used with new feature columns and losses by setting
37# use_core_libs = True.
38
39
class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
  """An estimator using gradient boosted decision trees."""

  def __init__(self,
               learner_config,
               examples_per_layer,
               n_classes=2,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               label_keys=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=True,
               use_core_libs=False,
               output_leaf_index=False,
               override_global_step_value=None,
               num_quantiles=100):
    """Initializes a GradientBoostedDecisionTreeClassifier estimator instance.

    Args:
      learner_config: A config for the learner. If its `num_classes` field is
        0, it is set in place to `n_classes`.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. It can also be a function that computes the number of examples
        based on the depth of the layer that's being built.
      n_classes: Number of classes in the classification.
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the column for weights, or None if not
        weighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      use_core_libs: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in predictions
        dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
        [batch_size, num_trees]. For example,
          result_iter = classifier.predict(...)
          for result_dict in result_iter:
            # access leaf index list by result_dict["leaf_index"]
            # which contains one leaf index per tree
      override_global_step_value: If after the training is done, global step
        value must be reset to this value. This should be used to reset global
        step to a number > number of steps used to train the current ensemble.
        For example, the usual way is to train a number of trees and set a very
        large number of training steps. When the training is done (number of
        trees were trained), this parameter can be used to set the global step
        to a large value, making it look like that number of training steps ran.
        If None, no override of global step will happen.
      num_quantiles: Number of quantiles to build for numeric feature values.

    Raises:
      ValueError: If `learner_config.num_classes` is already set (non-zero) and
        differs from `n_classes`.
    """
    if n_classes > 2:
      # For multi-class classification, use our loss implementation that
      # supports second order derivative.
      def loss_fn(labels, logits, weights=None):
        result = losses.per_example_maxent_loss(
            labels=labels,
            logits=logits,
            weights=weights,
            num_classes=n_classes)
        # Only the first element of the returned pair is reduced to a scalar.
        return math_ops.reduce_mean(result[0])
    else:
      # None is handed to the head as-is for the binary case.
      loss_fn = None
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name=weight_column_name,
        enable_centered_bias=False,
        loss_fn=loss_fn,
        label_keys=label_keys)
    if learner_config.num_classes == 0:
      # An unset class count in the learner config is filled in from n_classes.
      learner_config.num_classes = n_classes
    elif learner_config.num_classes != n_classes:
      # Arguments follow the order of the labels in the message.
      raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
                       (n_classes, learner_config.num_classes))
    super(GradientBoostedDecisionTreeClassifier, self).__init__(
        model_fn=model.model_builder,
        params={
            'head': head,
            'feature_columns': feature_columns,
            'learner_config': learner_config,
            'num_trees': num_trees,
            'weight_column_name': weight_column_name,
            'examples_per_layer': examples_per_layer,
            'center_bias': center_bias,
            'logits_modifier_function': logits_modifier_function,
            'use_core_libs': use_core_libs,
            'output_leaf_index': output_leaf_index,
            'override_global_step_value': override_global_step_value,
            'num_quantiles': num_quantiles,
        },
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
146
147
class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
  """An estimator using gradient boosted decision trees."""

  def __init__(self,
               learner_config,
               examples_per_layer,
               label_dimension=1,
               num_trees=None,
               feature_columns=None,
               label_name=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=True,
               use_core_libs=False,
               output_leaf_index=False,
               override_global_step_value=None,
               num_quantiles=100):
    """Initializes a GradientBoostedDecisionTreeRegressor estimator instance.

    Args:
      learner_config: A config for the learner. Its `num_classes` field is
        overwritten in place based on `label_dimension`.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. It can also be a function that computes the number of examples
        based on the depth of the layer that's being built.
      label_dimension: Number of regression labels per example. This is the size
        of the last dimension of the labels `Tensor` (typically, this has shape
        `[batch_size, label_dimension]`).
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      label_name: String, name of the key in label dict. Can be null if label is
        a tensor (single headed models).
      weight_column_name: Name of the column for weights, or None if not
        weighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      use_core_libs: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in predictions
        dict by the key 'leaf_index'. For example,
          result_iter = regressor.predict(...)
          for example_prediction_result in result_iter:
            # access leaf index list by
            # example_prediction_result["leaf_index"]
            # which contains one leaf index per tree
      override_global_step_value: If after the training is done, global step
        value must be reset to this value. This should be used to reset global
        step to a number > number of steps used to train the current ensemble.
        For example, the usual way is to train a number of trees and set a very
        large number of training steps. When the training is done (number of
        trees were trained), this parameter can be used to set the global step
        to a large value, making it look like that number of training steps ran.
        If None, no override of global step will happen.
      num_quantiles: Number of quantiles to build for numeric feature values.
    """
    head = head_lib.regression_head(
        label_name=label_name,
        label_dimension=label_dimension,
        weight_column_name=weight_column_name,
        enable_centered_bias=False)
    # A single label is configured as num_classes = 2; otherwise one "class"
    # per label dimension.
    # NOTE(review): presumably the learner treats num_classes == 2 as a
    # single-logit problem -- confirm against the learner implementation.
    if label_dimension == 1:
      learner_config.num_classes = 2
    else:
      learner_config.num_classes = label_dimension
    super(GradientBoostedDecisionTreeRegressor, self).__init__(
        model_fn=model.model_builder,
        params={
            'head': head,
            'feature_columns': feature_columns,
            'learner_config': learner_config,
            'num_trees': num_trees,
            'weight_column_name': weight_column_name,
            'examples_per_layer': examples_per_layer,
            'logits_modifier_function': logits_modifier_function,
            'center_bias': center_bias,
            'use_core_libs': use_core_libs,
            # Forward the caller's choice instead of a hard-coded False so the
            # documented `output_leaf_index` argument actually takes effect.
            'output_leaf_index': output_leaf_index,
            'override_global_step_value': override_global_step_value,
            'num_quantiles': num_quantiles,
        },
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
239
240
class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
  """An estimator using gradient boosted decision trees.

  Useful for training with user specified `Head`.
  """

  def __init__(self,
               learner_config,
               examples_per_layer,
               head,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=True,
               use_core_libs=False,
               output_leaf_index=False,
               override_global_step_value=None,
               num_quantiles=100):
    """Initializes a GradientBoostedDecisionTreeEstimator estimator instance.

    Args:
      learner_config: A config for the learner.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. It can also be a function that computes the number of examples
        based on the depth of the layer that's being built.
      head: `Head` instance.
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the column for weights, or None if not
        weighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      use_core_libs: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in predictions
        dict by the key 'leaf_index'. For example,
          result_iter = estimator.predict(...)
          for example_prediction_result in result_iter:
            # access leaf index list by
            # example_prediction_result["leaf_index"]
            # which contains one leaf index per tree
      override_global_step_value: If after the training is done, global step
        value must be reset to this value. This should be used to reset global
        step to a number > number of steps used to train the current ensemble.
        For example, the usual way is to train a number of trees and set a very
        large number of training steps. When the training is done (number of
        trees were trained), this parameter can be used to set the global step
        to a large value, making it look like that number of training steps ran.
        If None, no override of global step will happen.
      num_quantiles: Number of quantiles to build for numeric feature values.
    """
    super(GradientBoostedDecisionTreeEstimator, self).__init__(
        model_fn=model.model_builder,
        params={
            'head': head,
            'feature_columns': feature_columns,
            'learner_config': learner_config,
            'num_trees': num_trees,
            'weight_column_name': weight_column_name,
            'examples_per_layer': examples_per_layer,
            'logits_modifier_function': logits_modifier_function,
            'center_bias': center_bias,
            'use_core_libs': use_core_libs,
            # Forward the caller's choice instead of a hard-coded False so the
            # documented `output_leaf_index` argument actually takes effect.
            'output_leaf_index': output_leaf_index,
            'override_global_step_value': override_global_step_value,
            'num_quantiles': num_quantiles,
        },
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
321
322
class GradientBoostedDecisionTreeRanker(estimator.Estimator):
  """A ranking estimator using gradient boosted decision trees."""

  def __init__(self,
               learner_config,
               examples_per_layer,
               head,
               ranking_model_pair_keys,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               label_keys=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=False,
               use_core_libs=False,
               output_leaf_index=False,
               override_global_step_value=None,
               num_quantiles=100):
    """Initializes a GradientBoostedDecisionTreeRanker instance.

    The ranker is trained on pairwise data and can then be used for inference
    on non-paired data. This is essentially LambdaMart.

    Args:
      learner_config: A config for the learner.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. May also be a function computing that number from the depth of
        the layer being built.
      head: `Head` instance.
      ranking_model_pair_keys: Keys that tell apart the features of the left
        and right halves of a training pair. For instance, an Example with
        features "a.f1" and "b.f1" would use the keys ("a", "b").
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the weights column, or None when examples
        are unweighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2. Not referenced by
        this constructor.
      feature_engineering_fn: Feature engineering function. Receives the
        features and labels produced by `input_fn` and returns the features
        and labels that are fed into the model.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      use_core_libs: Whether feature columns and loss come from the core (as
        opposed to contrib) version of tensorflow.
      output_leaf_index: Whether to output leaf indices along with predictions
        during inference. The leaf node indexes are exposed in the predictions
        dict under the key 'leaf_index', as a rank 2 Tensor of shape
        [batch_size, num_trees]. For example,
          result_iter = ranker.predict(...)
          for result_dict in result_iter:
            # result_dict["leaf_index"] holds one leaf index per tree
      override_global_step_value: If set, the global step is reset to this
        value once training is done. Use it to reset the global step to a
        number > the number of steps used to train the current ensemble. The
        usual pattern is to request a number of trees together with a very
        large number of training steps; once the trees are trained, this value
        makes it look as if that many training steps ran. If None, the global
        step is not overridden.
      num_quantiles: Number of quantiles to build for numeric feature values.

    Note: this constructor performs no validation of its own; it only forwards
    its arguments to the base estimator.
    """
    # Everything the ranking model function needs travels through `params`.
    ranking_params = {
        'head': head,
        'n_classes': 2,
        'feature_columns': feature_columns,
        'learner_config': learner_config,
        'num_trees': num_trees,
        'weight_column_name': weight_column_name,
        'examples_per_layer': examples_per_layer,
        'center_bias': center_bias,
        'logits_modifier_function': logits_modifier_function,
        'use_core_libs': use_core_libs,
        'output_leaf_index': output_leaf_index,
        'ranking_model_pair_keys': ranking_model_pair_keys,
        'override_global_step_value': override_global_step_value,
        'num_quantiles': num_quantiles,
    }
    super(GradientBoostedDecisionTreeRanker, self).__init__(
        model_fn=model.ranking_model_builder,
        params=ranking_params,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
414
415
416# When using this estimator, make sure to regularize the hessian (at least l2,
417# min_node_weight)!
418# TODO(nponomareva): extend to take multiple quantiles in one go.
class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator):
  """An estimator that does quantile regression and returns quantile estimates."""

  def __init__(self,
               learner_config,
               examples_per_layer,
               quantiles,
               label_dimension=1,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=True,
               use_core_libs=False,
               output_leaf_index=False,
               override_global_step_value=None,
               num_quantiles=100):
    """Initializes a GradientBoostedDecisionTreeQuantileRegressor instance.

    Args:
      learner_config: A config for the learner. Its `num_classes` field is
        overwritten in place with `max(2, label_dimension)`.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. It can also be a function that computes the number of examples
        based on the depth of the layer that's being built.
      quantiles: a list of quantiles for the loss, each between 0 and 1. For
        now it must contain exactly one element.
      label_dimension: Dimension of regression label. This is the size of the
        last dimension of the labels `Tensor` (typically, this has shape
        `[batch_size, label_dimension]`). When label_dimension>1, it is
        recommended to use multiclass strategy diagonal hessian or full hessian.
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the column for weights, or None if not
        weighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      use_core_libs: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in predictions
        dict by the key 'leaf_index'. For example,
          result_iter = regressor.predict(...)
          for example_prediction_result in result_iter:
            # access leaf index list by
            # example_prediction_result["leaf_index"]
            # which contains one leaf index per tree
      override_global_step_value: If after the training is done, global step
        value must be reset to this value. This should be used to reset global
        step to a number > number of steps used to train the current ensemble.
        For example, the usual way is to train a number of trees and set a very
        large number of training steps. When the training is done (number of
        trees were trained), this parameter can be used to set the global step
        to a large value, making it look like that number of training steps ran.
        If None, no override of global step will happen.
      num_quantiles: Number of quantiles to build for numeric feature values.

    Raises:
      ValueError: If `quantiles` does not contain exactly one quantile.
    """

    # Reject an empty list too: it would otherwise fail later with an
    # unhelpful IndexError on quantiles[0].
    if len(quantiles) != 1:
      raise ValueError('For now, just one quantile per estimator is supported')

    def _quantile_regression_head(quantile):
      # Use quantile regression.
      head = custom_loss_head.CustomLossHead(
          loss_fn=functools.partial(
              losses.per_example_quantile_regression_loss, quantile=quantile),
          link_fn=array_ops.identity,
          logit_dimension=label_dimension)
      return head

    # The learner is always configured with at least two classes.
    learner_config.num_classes = max(2, label_dimension)

    super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__(
        model_fn=model.model_builder,
        params={
            'head': _quantile_regression_head(quantiles[0]),
            'feature_columns': feature_columns,
            'learner_config': learner_config,
            'num_trees': num_trees,
            'weight_column_name': weight_column_name,
            'examples_per_layer': examples_per_layer,
            'logits_modifier_function': logits_modifier_function,
            'center_bias': center_bias,
            'use_core_libs': use_core_libs,
            # Forward the caller's choice instead of a hard-coded False so the
            # documented `output_leaf_index` argument actually takes effect.
            'output_leaf_index': output_leaf_index,
            'override_global_step_value': override_global_step_value,
            'num_quantiles': num_quantiles,
        },
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
516
517
518# ================== New Estimator interface===================================
519# The estimators below use new core Estimator interface and must be used with
520# new feature columns and heads.
521
522
523# For multiclass classification, use the following head since it uses loss
524# that is twice differentiable.
def core_multiclass_head(
    n_classes,
    weight_column=None,
    loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
  """Builds a core multiclass head backed by the maxent loss.

  Args:
    n_classes: Number of classes; forwarded to both the loss and the head.
    weight_column: Forwarded unchanged to the core head.
    loss_reduction: Forwarded unchanged to the core head.

  Returns:
    The head object created by the core head library.
  """

  def _per_example_maxent(labels, logits):
    # Weights are deliberately left out: the head already multiplies by them.
    loss_and_extra = losses.per_example_maxent_loss(
        labels=labels, logits=logits, weights=None, num_classes=n_classes)
    # Only the first element of the returned pair is handed back to the head.
    return loss_and_extra[0]

  # pylint:disable=protected-access
  return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss(
      n_classes=n_classes,
      loss_fn=_per_example_maxent,
      loss_reduction=loss_reduction,
      weight_column=weight_column)
  # pylint:enable=protected-access
546
547
# For quantile regression, use this head with Core..Estimator, or use
# Core..QuantileRegressor directly.
def core_quantile_regression_head(
    quantiles,
    label_dimension=1,
    weight_column=None,
    loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
  """Builds a core regression head backed by the quantile regression loss.

  Args:
    quantiles: Passed as the `quantile` argument of the per-example quantile
      regression loss.
    label_dimension: Forwarded unchanged to the core regression head.
    weight_column: Forwarded unchanged to the core regression head.
    loss_reduction: Forwarded unchanged to the core regression head.

  Returns:
    The head object created by the core head library.
  """

  def _per_example_quantile(labels, logits):
    # Weights are deliberately left out: the head already multiplies by them.
    loss_and_extra = losses.per_example_quantile_regression_loss(
        labels=labels,
        predictions=logits,
        weights=None,
        quantile=quantiles)
    # Only the first element of the returned pair is handed back to the head.
    return loss_and_extra[0]

  # pylint:disable=protected-access
  return core_head_lib._regression_head(
      label_dimension=label_dimension,
      loss_fn=_per_example_quantile,
      loss_reduction=loss_reduction,
      weight_column=weight_column)
  # pylint:enable=protected-access
574
575
class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
  """An estimator using gradient boosted decision trees.

  Useful for training with user specified `Head`.
  """

  def __init__(self,
               learner_config,
               examples_per_layer,
               head,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               label_keys=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=True,
               output_leaf_index=False,
               num_quantiles=100):
    """Initializes a core version of GradientBoostedDecisionTreeEstimator.

    Args:
      learner_config: A config for the learner.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. May also be a function computing that number from the depth of
        the layer being built.
      head: `Head` instance.
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the weights column, or None when examples
        are unweighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2. Not referenced by
        this constructor.
      feature_engineering_fn: Feature engineering function. Receives the
        features and labels produced by `input_fn` and returns the features
        and labels that are fed into the model. Not referenced by this
        constructor.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      output_leaf_index: Whether to output leaf indices along with predictions
        during inference. The leaf node indexes are exposed in the predictions
        dict under the key 'leaf_index'. For example,
          result_iter = estimator.predict(...)
          for example_prediction_result in result_iter:
            # example_prediction_result["leaf_index"] holds one leaf index
            # per tree
      num_quantiles: Number of quantiles to build for numeric feature values.
    """

    def _model_fn(features, labels, mode, config):
      # The params dict is rebuilt on every invocation of the model function;
      # core libs are always on and the global step is never overridden here.
      builder_params = {
          'head': head,
          'feature_columns': feature_columns,
          'learner_config': learner_config,
          'num_trees': num_trees,
          'weight_column_name': weight_column_name,
          'examples_per_layer': examples_per_layer,
          'center_bias': center_bias,
          'logits_modifier_function': logits_modifier_function,
          'use_core_libs': True,
          'output_leaf_index': output_leaf_index,
          'override_global_step_value': None,
          'num_quantiles': num_quantiles,
      }
      return model.model_builder(
          features=features,
          labels=labels,
          mode=mode,
          config=config,
          params=builder_params,
          output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)

    super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)
653
654
class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
  """Core-estimator ranker backed by gradient boosted decision trees."""

  def __init__(self,
               learner_config,
               examples_per_layer,
               head,
               ranking_model_pair_keys,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               label_keys=None,
               logits_modifier_function=None,
               center_bias=False,
               output_leaf_index=False,
               num_quantiles=100):
    """Builds a ranking estimator on top of `model.ranking_model_builder`.

    The estimator is trained on pairwise data and can then be used for
    inference on non-paired data (essentially LambdaMart).

    Args:
      learner_config: A config for the learner.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. May also be a function computing that number from the depth of
        the layer being built.
      head: `Head` instance.
      ranking_model_pair_keys: Keys that tell apart the features of the left
        and right halves of a training pair. For an Example with features
        "a.f1" and "b.f1" the keys would be ("a", "b").
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the weight column, or None when examples are
        not weighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Note: this argument is accepted but is not referenced
        anywhere in this constructor.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      output_leaf_index: Whether to output leaf indices along with predictions
        during inference. They are exposed in the predictions dict under the
        key 'leaf_index' as a rank-2 Tensor of shape [batch_size, num_trees].
        For example:
          result_iter = ranker.predict(...)
          for result_dict in result_iter:
            # One leaf index per tree.
            result_dict["leaf_index"]
      num_quantiles: Number of quantiles to build for numeric feature values.

    Raises:
      ValueError: Documented for an invalid learner_config. No validation is
        performed in this constructor itself, so any such error would have to
        originate from the model builder.
    """

    def _ranking_model_fn(features, labels, mode, config):
      # The settings are captured from the constructor arguments by closure
      # and a fresh dict is assembled on every invocation.
      builder_params = {
          'head': head,
          'n_classes': 2,
          'feature_columns': feature_columns,
          'learner_config': learner_config,
          'num_trees': num_trees,
          'weight_column_name': weight_column_name,
          'examples_per_layer': examples_per_layer,
          'center_bias': center_bias,
          'logits_modifier_function': logits_modifier_function,
          'use_core_libs': True,
          'output_leaf_index': output_leaf_index,
          'ranking_model_pair_keys': ranking_model_pair_keys,
          'override_global_step_value': None,
          'num_quantiles': num_quantiles,
      }
      spec_output = model.ModelBuilderOutputType.ESTIMATOR_SPEC
      return model.ranking_model_builder(
          features=features,
          labels=labels,
          mode=mode,
          config=config,
          params=builder_params,
          output_type=spec_output)

    super(CoreGradientBoostedDecisionTreeRanker, self).__init__(
        model_fn=_ranking_model_fn, model_dir=model_dir, config=config)
736
737
738# When using this estimator, make sure to regularize the hessian (at least l2,
739# min_node_weight)!
740# TODO(nponomareva): extend to take multiple quantiles in one go.
class CoreGradientBoostedDecisionTreeQuantileRegressor(
    core_estimator.Estimator):
  """An estimator that does quantile regression and returns quantile estimates."""

  def __init__(self,
               learner_config,
               examples_per_layer,
               quantiles,
               label_dimension=1,
               num_trees=None,
               feature_columns=None,
               weight_column_name=None,
               model_dir=None,
               config=None,
               label_keys=None,
               feature_engineering_fn=None,
               logits_modifier_function=None,
               center_bias=True,
               output_leaf_index=False,
               num_quantiles=100):
    """Initializes a core gradient boosted trees quantile regressor.

    Args:
      learner_config: A config for the learner.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. It can also be a function that computes the number of examples
        based on the depth of the layer that's being built.
      quantiles: a list of quantiles for the loss, each between 0 and 1. It
        must currently contain exactly one entry; that entry is handed to
        `core_quantile_regression_head`.
      label_dimension: Dimension of regression label. This is the size of the
        last dimension of the labels `Tensor` (typically, this has shape
        `[batch_size, label_dimension]`). When label_dimension>1, it is
        recommended to use multiclass strategy diagonal hessian or full hessian.
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the column for weights, or None if not
        weighted. Also passed to the head as its `weight_column`.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Note: accepted but not referenced by this
        constructor.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model. Note: accepted but not
        referenced by this constructor.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in predictions
        dict by the key 'leaf_index'. For example:
          result_iter = regressor.predict(...)
          for example_prediction_result in result_iter:
            # Contains one leaf index per tree.
            example_prediction_result["leaf_index"]
      num_quantiles: Number of quantiles to build for numeric feature values.

    Raises:
      ValueError: If `quantiles` does not contain exactly one quantile.
    """
    # Reject an empty `quantiles` up front as well: it would otherwise slip
    # through and only fail later with an IndexError on `quantiles[0]` when the
    # model function is eventually invoked.
    if len(quantiles) != 1:
      raise ValueError('For now, just one quantile per estimator is supported')

    def _model_fn(features, labels, mode, config):
      # The quantile head is created on each model_fn invocation from the
      # single supported quantile.
      quantile_head = core_quantile_regression_head(
          quantiles[0],
          label_dimension=label_dimension,
          weight_column=weight_column_name)
      return model.model_builder(
          features=features,
          labels=labels,
          mode=mode,
          config=config,
          params={
              'head': quantile_head,
              'feature_columns': feature_columns,
              'learner_config': learner_config,
              'num_trees': num_trees,
              'weight_column_name': weight_column_name,
              'examples_per_layer': examples_per_layer,
              'center_bias': center_bias,
              'logits_modifier_function': logits_modifier_function,
              'use_core_libs': True,
              'output_leaf_index': output_leaf_index,
              'override_global_step_value': None,
              'num_quantiles': num_quantiles,
          },
          output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)

    super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)
838