• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/resource_mgr.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/kernels/boosted_trees/resources.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 
27 namespace tensorflow {
28 
29 REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource);
30 
31 REGISTER_KERNEL_BUILDER(
32     Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU),
33     IsResourceInitialized<BoostedTreesEnsembleResource>);
34 
35 // Creates a tree ensemble resource.
36 class BoostedTreesCreateEnsembleOp : public OpKernel {
37  public:
BoostedTreesCreateEnsembleOp(OpKernelConstruction * context)38   explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context)
39       : OpKernel(context) {}
40 
Compute(OpKernelContext * context)41   void Compute(OpKernelContext* context) override {
42     // Get the stamp token.
43     const Tensor* stamp_token_t;
44     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
45     int64_t stamp_token = stamp_token_t->scalar<int64>()();
46 
47     // Get the tree ensemble proto.
48     const Tensor* tree_ensemble_serialized_t;
49     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
50                                            &tree_ensemble_serialized_t));
51     std::unique_ptr<BoostedTreesEnsembleResource> result(
52         new BoostedTreesEnsembleResource());
53     if (!result->InitFromSerialized(
54             tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token)) {
55       result->Unref();
56       result.release();  // Needed due to the `->Unref` above, to prevent UAF
57       OP_REQUIRES(
58           context, false,
59           errors::InvalidArgument("Unable to parse tree ensemble proto."));
60     }
61 
62     // Only create one, if one does not exist already. Report status for all
63     // other exceptions.
64     auto status =
65         CreateResource(context, HandleFromInput(context, 0), result.release());
66     if (status.code() != tensorflow::error::ALREADY_EXISTS) {
67       OP_REQUIRES_OK(context, status);
68     }
69   }
70 };
71 
72 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU),
73                         BoostedTreesCreateEnsembleOp);
74 
75 // Op for retrieving some model states (needed for training).
76 class BoostedTreesGetEnsembleStatesOp : public OpKernel {
77  public:
BoostedTreesGetEnsembleStatesOp(OpKernelConstruction * context)78   explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context)
79       : OpKernel(context) {}
80 
Compute(OpKernelContext * context)81   void Compute(OpKernelContext* context) override {
82     // Looks up the resource.
83     core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
84     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
85                                            &tree_ensemble_resource));
86     tf_shared_lock l(*tree_ensemble_resource->get_mutex());
87 
88     // Sets the outputs.
89     const int num_trees = tree_ensemble_resource->num_trees();
90     const int num_finalized_trees =
91         (num_trees <= 0 ||
92          tree_ensemble_resource->IsTreeFinalized(num_trees - 1))
93             ? num_trees
94             : num_trees - 1;
95     const int num_attempted_layers =
96         tree_ensemble_resource->GetNumLayersAttempted();
97 
98     // growing_metadata
99     Tensor* output_stamp_token_t = nullptr;
100     Tensor* output_num_trees_t = nullptr;
101     Tensor* output_num_finalized_trees_t = nullptr;
102     Tensor* output_num_attempted_layers_t = nullptr;
103     Tensor* output_last_layer_nodes_range_t = nullptr;
104 
105     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
106                                                      &output_stamp_token_t));
107     OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(),
108                                                      &output_num_trees_t));
109     OP_REQUIRES_OK(context,
110                    context->allocate_output(2, TensorShape(),
111                                             &output_num_finalized_trees_t));
112     OP_REQUIRES_OK(context,
113                    context->allocate_output(3, TensorShape(),
114                                             &output_num_attempted_layers_t));
115     OP_REQUIRES_OK(context, context->allocate_output(
116                                 4, {2}, &output_last_layer_nodes_range_t));
117 
118     output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
119     output_num_trees_t->scalar<int32>()() = num_trees;
120     output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees;
121     output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers;
122 
123     int32_t range_start;
124     int32_t range_end;
125     tree_ensemble_resource->GetLastLayerNodesRange(&range_start, &range_end);
126 
127     output_last_layer_nodes_range_t->vec<int32>()(0) = range_start;
128     // For a completely empty ensemble, this will be 0. To make it a valid range
129     // we add this max cond.
130     output_last_layer_nodes_range_t->vec<int32>()(1) = std::max(1, range_end);
131   }
132 };
133 
134 REGISTER_KERNEL_BUILDER(
135     Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU),
136     BoostedTreesGetEnsembleStatesOp);
137 
138 // Op for serializing a model.
139 class BoostedTreesSerializeEnsembleOp : public OpKernel {
140  public:
BoostedTreesSerializeEnsembleOp(OpKernelConstruction * context)141   explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context)
142       : OpKernel(context) {}
143 
Compute(OpKernelContext * context)144   void Compute(OpKernelContext* context) override {
145     core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
146     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
147                                            &tree_ensemble_resource));
148     tf_shared_lock l(*tree_ensemble_resource->get_mutex());
149     Tensor* output_stamp_token_t = nullptr;
150     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
151                                                      &output_stamp_token_t));
152     output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
153     Tensor* output_proto_t = nullptr;
154     OP_REQUIRES_OK(context,
155                    context->allocate_output(1, TensorShape(), &output_proto_t));
156     output_proto_t->scalar<tstring>()() =
157         tree_ensemble_resource->SerializeAsString();
158   }
159 };
160 
161 REGISTER_KERNEL_BUILDER(
162     Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU),
163     BoostedTreesSerializeEnsembleOp);
164 
165 // Op for deserializing a tree ensemble variable from a checkpoint.
166 class BoostedTreesDeserializeEnsembleOp : public OpKernel {
167  public:
BoostedTreesDeserializeEnsembleOp(OpKernelConstruction * context)168   explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context)
169       : OpKernel(context) {}
170 
Compute(OpKernelContext * context)171   void Compute(OpKernelContext* context) override {
172     core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
173     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
174                                            &tree_ensemble_resource));
175     mutex_lock l(*tree_ensemble_resource->get_mutex());
176 
177     // Get the stamp token.
178     const Tensor* stamp_token_t;
179     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
180     int64_t stamp_token = stamp_token_t->scalar<int64>()();
181 
182     // Get the tree ensemble proto.
183     const Tensor* tree_ensemble_serialized_t;
184     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
185                                            &tree_ensemble_serialized_t));
186     // Deallocate all the previous objects on the resource.
187     tree_ensemble_resource->Reset();
188     OP_REQUIRES(
189         context,
190         tree_ensemble_resource->InitFromSerialized(
191             tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token),
192         errors::InvalidArgument("Unable to parse tree ensemble proto."));
193   }
194 };
195 
196 REGISTER_KERNEL_BUILDER(
197     Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU),
198     BoostedTreesDeserializeEnsembleOp);
199 
200 }  // namespace tensorflow
201