• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/resource_mgr.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 
21 namespace tensorflow {
22 namespace boosted_trees {
23 
24 REGISTER_RESOURCE_HANDLE_OP(DecisionTreeEnsembleResource);
25 
26 REGISTER_OP("TreeEnsembleIsInitializedOp")
27     .Input("tree_ensemble_handle: resource")
28     .Output("is_initialized: bool")
__anond0e057b60102(shape_inference::InferenceContext* c) 29     .SetShapeFn([](shape_inference::InferenceContext* c) {
30       shape_inference::ShapeHandle unused_input;
31       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
32       c->set_output(0, c->Scalar());
33       return Status::OK();
34     })
35     .Doc(R"doc(
36 Checks whether a tree ensemble has been initialized.
37 )doc");
38 
39 REGISTER_OP("CreateTreeEnsembleVariable")
40     .Input("tree_ensemble_handle: resource")
41     .Input("stamp_token: int64")
42     .Input("tree_ensemble_config: string")
__anond0e057b60202(shape_inference::InferenceContext* c) 43     .SetShapeFn([](shape_inference::InferenceContext* c) {
44       shape_inference::ShapeHandle unused_input;
45       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
46       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
47       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
48       return Status::OK();
49     })
50     .Doc(R"doc(
51 Creates a tree ensemble model and returns a handle to it.
52 
53 tree_ensemble_handle: Handle to the tree ensemble resource to be created.
54 stamp_token: Token to use as the initial value of the resource stamp.
55 tree_ensemble_config: Serialized proto of the tree ensemble.
56 )doc");
57 
58 REGISTER_OP("TreeEnsembleStampToken")
59     .Input("tree_ensemble_handle: resource")
60     .Output("stamp_token: int64")
__anond0e057b60302(shape_inference::InferenceContext* c) 61     .SetShapeFn([](shape_inference::InferenceContext* c) {
62       shape_inference::ShapeHandle unused_input;
63       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
64       c->set_output(0, c->Scalar());
65       return Status::OK();
66     })
67     .Doc(R"doc(
68 Retrieves the tree ensemble resource stamp token.
69 
70 tree_ensemble_handle: Handle to the tree ensemble.
71 stamp_token: Stamp token of the tree ensemble resource.
72 )doc");
73 
74 REGISTER_OP("TreeEnsembleSerialize")
75     .Input("tree_ensemble_handle: resource")
76     .Output("stamp_token: int64")
77     .Output("tree_ensemble_config: string")
__anond0e057b60402(shape_inference::InferenceContext* c) 78     .SetShapeFn([](shape_inference::InferenceContext* c) {
79       shape_inference::ShapeHandle unused_input;
80       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
81       c->set_output(0, c->Scalar());
82       c->set_output(1, c->Scalar());
83       return Status::OK();
84     })
85     .Doc(R"doc(
86 Serializes the tree ensemble to a proto.
87 
88 tree_ensemble_handle: Handle to the tree ensemble.
89 stamp_token: Stamp token of the tree ensemble resource.
90 tree_ensemble_config: Serialized proto of the ensemble.
91 )doc");
92 
93 REGISTER_OP("TreeEnsembleDeserialize")
94     .Input("tree_ensemble_handle: resource")
95     .Input("stamp_token: int64")
96     .Input("tree_ensemble_config: string")
__anond0e057b60502(shape_inference::InferenceContext* c) 97     .SetShapeFn([](shape_inference::InferenceContext* c) {
98       shape_inference::ShapeHandle unused_input;
99       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
100       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
101       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
102       return Status::OK();
103     })
104     .Doc(R"doc(
105 Deserializes a serialized tree ensemble config and replaces current tree
106 ensemble.
107 
108 tree_ensemble_handle: Handle to the tree ensemble.
109 stamp_token: Token to use as the new value of the resource stamp.
110 tree_ensemble_config: Serialized proto of the ensemble.
111 )doc");
112 
113 REGISTER_OP("TreeEnsembleUsedHandlers")
114     .Attr("num_all_handlers: int >= 0")
115     .Input("tree_ensemble_handle: resource")
116     .Input("stamp_token: int64")
117     .Output("num_used_handlers: int64")
118     .Output("used_handlers_mask: bool")
__anond0e057b60602(shape_inference::InferenceContext* c) 119     .SetShapeFn([](shape_inference::InferenceContext* c) {
120       shape_inference::ShapeHandle unused_input;
121       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
122       c->set_output(0, c->Scalar());
123       int num_all_handlers;
124       c->GetAttr("num_all_handlers", &num_all_handlers).IgnoreError();
125       c->set_output(1, {c->Vector(num_all_handlers)});
126 
127       return Status::OK();
128     })
129     .Doc(R"doc(
130 Returns the mask of used handlers along with the number of non-zero elements in
131 this mask. Used in feature selection.
132 
133 tree_ensemble_handle: Handle to the tree ensemble.
134 stamp_token: Token to use as the new value of the resource stamp.
135 num_used_handlers: number of feature column handlers used in the model.
136 used_handlers_mask: A boolean vector of showing which handlers are used in the
137                     model.
138 )doc");
139 
140 }  // namespace boosted_trees
141 }  // namespace tensorflow
142