• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 namespace boosted_trees {
21 
22 REGISTER_OP("CenterTreeEnsembleBias")
23     .Attr("learner_config: string")
24     .Attr("centering_epsilon: float = 0.01")
25     .Input("tree_ensemble_handle: resource")
26     .Input("stamp_token: int64")
27     .Input("next_stamp_token: int64")
28     .Input("delta_updates: float")
29     .Output("continue_centering: bool")
__anond8bb5a410102(shape_inference::InferenceContext* c) 30     .SetShapeFn([](shape_inference::InferenceContext* c) {
31       shape_inference::ShapeHandle unused_input;
32       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
33       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
34       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
35       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_input));
36       c->set_output(0, c->Scalar());
37       return Status::OK();
38     })
39     .Doc(R"doc(
40 Centers the tree ensemble bias before adding trees based on feature splits.
41 
42 learner_config: Config for the learner of type LearnerConfig proto.
43 tree_ensemble_handle: Handle to the ensemble variable.
44 stamp_token: Stamp token for validating operation consistency.
45 next_stamp_token: Stamp token to be used for the next iteration.
46 delta_updates: Rank 1 Tensor containing delta updates per bias dimension.
47 continue_centering: Scalar indicating whether more centering is needed.
48 )doc");
49 
50 REGISTER_OP("GrowTreeEnsemble")
51     .Attr("learner_config: string")
52     .Attr("num_handlers: int >= 0")
53     .Attr("center_bias: bool")
54     .Input("tree_ensemble_handle: resource")
55     .Input("stamp_token: int64")
56     .Input("next_stamp_token: int64")
57     .Input("learning_rate: float")
58     .Input("dropout_seed: int64")
59     .Input("max_tree_depth: int32")
60     .Input("weak_learner_type: int32")
61     .Input("partition_ids: num_handlers * int32")
62     .Input("gains: num_handlers * float")
63     .Input("splits: num_handlers * string")
__anond8bb5a410202(shape_inference::InferenceContext* c) 64     .SetShapeFn([](shape_inference::InferenceContext* c) {
65       shape_inference::ShapeHandle unused_input;
66       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
67       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
68       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
69       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input));
70       // Dropout seed.
71       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input));
72       // Maximum tree depth.
73       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input));
74       return Status::OK();
75     })
76     .Doc(R"doc(
77 Grows the tree ensemble by either adding a layer to the last tree being grown
78 or by starting a new tree.
79 
80 learner_config: Config for the learner of type LearnerConfig proto.
81 num_handlers: Number of handlers generating candidates.
82 tree_ensemble_handle: Handle to the ensemble variable.
83 stamp_token: Stamp token for validating operation consistency.
84 next_stamp_token: Stamp token to be used for the next iteration.
85 learning_rate: Scalar learning rate.
86 weak_learner_type: The type of weak learner to use.
87 partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
88 gains: List of Rank 1 Tensors containing gains per candidate.
89 splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.
90 )doc");
91 
92 REGISTER_OP("TreeEnsembleStats")
93     .Input("tree_ensemble_handle: resource")
94     .Input("stamp_token: int64")
95     .Output("num_trees: int64")
96     .Output("num_layers: int64")
97     .Output("active_tree: int64")
98     .Output("active_layer: int64")
99     .Output("attempted_trees: int64")
100     .Output("attempted_layers: int64")
__anond8bb5a410302(shape_inference::InferenceContext* c) 101     .SetShapeFn([](shape_inference::InferenceContext* c) {
102       shape_inference::ShapeHandle unused_input;
103       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
104       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
105       c->set_output(0, c->Scalar());
106       c->set_output(1, c->Scalar());
107       c->set_output(2, c->Scalar());
108       c->set_output(3, c->Scalar());
109       c->set_output(4, c->Scalar());
110       c->set_output(5, c->Scalar());
111       return Status::OK();
112     })
113     .Doc(R"doc(
114 Retrieves stats related to the tree ensemble.
115 
116 tree_ensemble_handle: Handle to the ensemble variable.
117 stamp_token: Stamp token for validating operation consistency.
118 num_trees: Scalar indicating the number of finalized trees in the ensemble.
119 num_layers: Scalar indicating the number of layers in the ensemble.
120 active_tree: Scalar indicating the active tree being trained.
121 active_layer: Scalar indicating the active layer being trained.
122 )doc");
123 
124 }  // namespace boosted_trees
125 }  // namespace tensorflow
126