1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/resource_mgr.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 21 namespace tensorflow { 22 namespace boosted_trees { 23 24 REGISTER_RESOURCE_HANDLE_OP(DecisionTreeEnsembleResource); 25 26 REGISTER_OP("TreeEnsembleIsInitializedOp") 27 .Input("tree_ensemble_handle: resource") 28 .Output("is_initialized: bool") __anond0e057b60102(shape_inference::InferenceContext* c) 29 .SetShapeFn([](shape_inference::InferenceContext* c) { 30 shape_inference::ShapeHandle unused_input; 31 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 32 c->set_output(0, c->Scalar()); 33 return Status::OK(); 34 }) 35 .Doc(R"doc( 36 Checks whether a tree ensemble has been initialized. 37 )doc"); 38 39 REGISTER_OP("CreateTreeEnsembleVariable") 40 .Input("tree_ensemble_handle: resource") 41 .Input("stamp_token: int64") 42 .Input("tree_ensemble_config: string") __anond0e057b60202(shape_inference::InferenceContext* c) 43 .SetShapeFn([](shape_inference::InferenceContext* c) { 44 shape_inference::ShapeHandle unused_input; 45 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 46 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); 47 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); 48 return Status::OK(); 49 }) 50 .Doc(R"doc( 51 Creates a tree ensemble model and returns a handle to it. 52 53 tree_ensemble_handle: Handle to the tree ensemble resource to be created. 54 stamp_token: Token to use as the initial value of the resource stamp. 55 tree_ensemble_config: Serialized proto of the tree ensemble. 56 )doc"); 57 58 REGISTER_OP("TreeEnsembleStampToken") 59 .Input("tree_ensemble_handle: resource") 60 .Output("stamp_token: int64") __anond0e057b60302(shape_inference::InferenceContext* c) 61 .SetShapeFn([](shape_inference::InferenceContext* c) { 62 shape_inference::ShapeHandle unused_input; 63 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 64 c->set_output(0, c->Scalar()); 65 return Status::OK(); 66 }) 67 .Doc(R"doc( 68 Retrieves the tree ensemble resource stamp token. 69 70 tree_ensemble_handle: Handle to the tree ensemble. 71 stamp_token: Stamp token of the tree ensemble resource. 72 )doc"); 73 74 REGISTER_OP("TreeEnsembleSerialize") 75 .Input("tree_ensemble_handle: resource") 76 .Output("stamp_token: int64") 77 .Output("tree_ensemble_config: string") __anond0e057b60402(shape_inference::InferenceContext* c) 78 .SetShapeFn([](shape_inference::InferenceContext* c) { 79 shape_inference::ShapeHandle unused_input; 80 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 81 c->set_output(0, c->Scalar()); 82 c->set_output(1, c->Scalar()); 83 return Status::OK(); 84 }) 85 .Doc(R"doc( 86 Serializes the tree ensemble to a proto. 87 88 tree_ensemble_handle: Handle to the tree ensemble. 89 stamp_token: Stamp token of the tree ensemble resource. 90 tree_ensemble_config: Serialized proto of the ensemble. 91 )doc"); 92 93 REGISTER_OP("TreeEnsembleDeserialize") 94 .Input("tree_ensemble_handle: resource") 95 .Input("stamp_token: int64") 96 .Input("tree_ensemble_config: string") __anond0e057b60502(shape_inference::InferenceContext* c) 97 .SetShapeFn([](shape_inference::InferenceContext* c) { 98 shape_inference::ShapeHandle unused_input; 99 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 100 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); 101 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); 102 return Status::OK(); 103 }) 104 .Doc(R"doc( 105 Deserializes a serialized tree ensemble config and replaces current tree 106 ensemble. 107 108 tree_ensemble_handle: Handle to the tree ensemble. 109 stamp_token: Token to use as the new value of the resource stamp. 110 tree_ensemble_config: Serialized proto of the ensemble. 111 )doc"); 112 113 REGISTER_OP("TreeEnsembleUsedHandlers") 114 .Attr("num_all_handlers: int >= 0") 115 .Input("tree_ensemble_handle: resource") 116 .Input("stamp_token: int64") 117 .Output("num_used_handlers: int64") 118 .Output("used_handlers_mask: bool") __anond0e057b60602(shape_inference::InferenceContext* c) 119 .SetShapeFn([](shape_inference::InferenceContext* c) { 120 shape_inference::ShapeHandle unused_input; 121 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); 122 c->set_output(0, c->Scalar()); 123 int num_all_handlers; 124 c->GetAttr("num_all_handlers", &num_all_handlers).IgnoreError(); 125 c->set_output(1, {c->Vector(num_all_handlers)}); 126 127 return Status::OK(); 128 }) 129 .Doc(R"doc( 130 Returns the mask of used handlers along with the number of non-zero elements in 131 this mask. Used in feature selection. 132 133 tree_ensemble_handle: Handle to the tree ensemble. 134 stamp_token: Token to use as the new value of the resource stamp. 135 num_used_handlers: number of feature column handlers used in the model. 136 used_handlers_mask: A boolean vector of showing which handlers are used in the 137 model. 138 )doc"); 139 140 } // namespace boosted_trees 141 } // namespace tensorflow 142