1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/op.h" 17 #include "tensorflow/core/framework/shape_inference.h" 18 19 namespace tensorflow { 20 namespace boosted_trees { 21 22 REGISTER_OP("CenterTreeEnsembleBias") 23 .Attr("learner_config: string") 24 .Attr("centering_epsilon: float = 0.01") 25 .Input("tree_ensemble_handle: resource") 26 .Input("stamp_token: int64") 27 .Input("next_stamp_token: int64") 28 .Input("delta_updates: float") 29 .Output("continue_centering: bool") __anond8bb5a410102(shape_inference::InferenceContext* c) 30 .SetShapeFn([](shape_inference::InferenceContext* c) { 31 shape_inference::ShapeHandle unused_input; 32 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 33 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); 34 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); 35 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_input)); 36 c->set_output(0, c->Scalar()); 37 return Status::OK(); 38 }) 39 .Doc(R"doc( 40 Centers the tree ensemble bias before adding trees based on feature splits. 41 42 learner_config: Config for the learner of type LearnerConfig proto. 43 tree_ensemble_handle: Handle to the ensemble variable. 44 stamp_token: Stamp token for validating operation consistency. 45 next_stamp_token: Stamp token to be used for the next iteration. 46 delta_updates: Rank 1 Tensor containing delta updates per bias dimension. 47 continue_centering: Scalar indicating whether more centering is needed. 48 )doc"); 49 50 REGISTER_OP("GrowTreeEnsemble") 51 .Attr("learner_config: string") 52 .Attr("num_handlers: int >= 0") 53 .Attr("center_bias: bool") 54 .Input("tree_ensemble_handle: resource") 55 .Input("stamp_token: int64") 56 .Input("next_stamp_token: int64") 57 .Input("learning_rate: float") 58 .Input("dropout_seed: int64") 59 .Input("max_tree_depth: int32") 60 .Input("weak_learner_type: int32") 61 .Input("partition_ids: num_handlers * int32") 62 .Input("gains: num_handlers * float") 63 .Input("splits: num_handlers * string") __anond8bb5a410202(shape_inference::InferenceContext* c) 64 .SetShapeFn([](shape_inference::InferenceContext* c) { 65 shape_inference::ShapeHandle unused_input; 66 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 67 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); 68 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); 69 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input)); 70 // Dropout seed. 71 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input)); 72 // Maximum tree depth. 73 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input)); 74 return Status::OK(); 75 }) 76 .Doc(R"doc( 77 Grows the tree ensemble by either adding a layer to the last tree being grown 78 or by starting a new tree. 79 80 learner_config: Config for the learner of type LearnerConfig proto. 81 num_handlers: Number of handlers generating candidates. 82 tree_ensemble_handle: Handle to the ensemble variable. 83 stamp_token: Stamp token for validating operation consistency. 84 next_stamp_token: Stamp token to be used for the next iteration. 85 learning_rate: Scalar learning rate. 86 weak_learner_type: The type of weak learner to use. 87 partition_ids: List of Rank 1 Tensors containing partition Id per candidate. 88 gains: List of Rank 1 Tensors containing gains per candidate. 89 splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate. 90 )doc"); 91 92 REGISTER_OP("TreeEnsembleStats") 93 .Input("tree_ensemble_handle: resource") 94 .Input("stamp_token: int64") 95 .Output("num_trees: int64") 96 .Output("num_layers: int64") 97 .Output("active_tree: int64") 98 .Output("active_layer: int64") 99 .Output("attempted_trees: int64") 100 .Output("attempted_layers: int64") __anond8bb5a410302(shape_inference::InferenceContext* c) 101 .SetShapeFn([](shape_inference::InferenceContext* c) { 102 shape_inference::ShapeHandle unused_input; 103 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 104 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); 105 c->set_output(0, c->Scalar()); 106 c->set_output(1, c->Scalar()); 107 c->set_output(2, c->Scalar()); 108 c->set_output(3, c->Scalar()); 109 c->set_output(4, c->Scalar()); 110 c->set_output(5, c->Scalar()); 111 return Status::OK(); 112 }) 113 .Doc(R"doc( 114 Retrieves stats related to the tree ensemble. 115 116 tree_ensemble_handle: Handle to the ensemble variable. 117 stamp_token: Stamp token for validating operation consistency. 118 num_trees: Scalar indicating the number of finalized trees in the ensemble. 119 num_layers: Scalar indicating the number of layers in the ensemble. 120 active_tree: Scalar indicating the active tree being trained. 121 active_layer: Scalar indicating the active layer being trained. 122 )doc"); 123 124 } // namespace boosted_trees 125 } // namespace tensorflow 126