• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/resource_mgr.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/kernels/boosted_trees/resources.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 
27 namespace tensorflow {
28 
29 REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource);
30 
31 REGISTER_KERNEL_BUILDER(
32     Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU),
33     IsResourceInitialized<BoostedTreesEnsembleResource>);
34 
35 // Creates a tree ensemble resource.
36 class BoostedTreesCreateEnsembleOp : public OpKernel {
37  public:
BoostedTreesCreateEnsembleOp(OpKernelConstruction * context)38   explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context)
39       : OpKernel(context) {}
40 
Compute(OpKernelContext * context)41   void Compute(OpKernelContext* context) override {
42     // Get the stamp token.
43     const Tensor* stamp_token_t;
44     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
45     int64 stamp_token = stamp_token_t->scalar<int64>()();
46 
47     // Get the tree ensemble proto.
48     const Tensor* tree_ensemble_serialized_t;
49     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
50                                            &tree_ensemble_serialized_t));
51     std::unique_ptr<BoostedTreesEnsembleResource> result(
52         new BoostedTreesEnsembleResource());
53     if (!result->InitFromSerialized(
54             tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token)) {
55       result->Unref();
56       OP_REQUIRES(
57           context, false,
58           errors::InvalidArgument("Unable to parse tree ensemble proto."));
59     }
60 
61     // Only create one, if one does not exist already. Report status for all
62     // other exceptions.
63     auto status =
64         CreateResource(context, HandleFromInput(context, 0), result.release());
65     if (status.code() != tensorflow::error::ALREADY_EXISTS) {
66       OP_REQUIRES_OK(context, status);
67     }
68   }
69 };
70 
71 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU),
72                         BoostedTreesCreateEnsembleOp);
73 
74 // Op for retrieving some model states (needed for training).
75 class BoostedTreesGetEnsembleStatesOp : public OpKernel {
76  public:
BoostedTreesGetEnsembleStatesOp(OpKernelConstruction * context)77   explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context)
78       : OpKernel(context) {}
79 
Compute(OpKernelContext * context)80   void Compute(OpKernelContext* context) override {
81     // Looks up the resource.
82     core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
83     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
84                                            &tree_ensemble_resource));
85     tf_shared_lock l(*tree_ensemble_resource->get_mutex());
86 
87     // Sets the outputs.
88     const int num_trees = tree_ensemble_resource->num_trees();
89     const int num_finalized_trees =
90         (num_trees <= 0 ||
91          tree_ensemble_resource->IsTreeFinalized(num_trees - 1))
92             ? num_trees
93             : num_trees - 1;
94     const int num_attempted_layers =
95         tree_ensemble_resource->GetNumLayersAttempted();
96 
97     // growing_metadata
98     Tensor* output_stamp_token_t = nullptr;
99     Tensor* output_num_trees_t = nullptr;
100     Tensor* output_num_finalized_trees_t = nullptr;
101     Tensor* output_num_attempted_layers_t = nullptr;
102     Tensor* output_last_layer_nodes_range_t = nullptr;
103 
104     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
105                                                      &output_stamp_token_t));
106     OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(),
107                                                      &output_num_trees_t));
108     OP_REQUIRES_OK(context,
109                    context->allocate_output(2, TensorShape(),
110                                             &output_num_finalized_trees_t));
111     OP_REQUIRES_OK(context,
112                    context->allocate_output(3, TensorShape(),
113                                             &output_num_attempted_layers_t));
114     OP_REQUIRES_OK(context, context->allocate_output(
115                                 4, {2}, &output_last_layer_nodes_range_t));
116 
117     output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
118     output_num_trees_t->scalar<int32>()() = num_trees;
119     output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees;
120     output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers;
121 
122     int32 range_start;
123     int32 range_end;
124     tree_ensemble_resource->GetLastLayerNodesRange(&range_start, &range_end);
125 
126     output_last_layer_nodes_range_t->vec<int32>()(0) = range_start;
127     // For a completely empty ensemble, this will be 0. To make it a valid range
128     // we add this max cond.
129     output_last_layer_nodes_range_t->vec<int32>()(1) = std::max(1, range_end);
130   }
131 };
132 
133 REGISTER_KERNEL_BUILDER(
134     Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU),
135     BoostedTreesGetEnsembleStatesOp);
136 
137 // Op for serializing a model.
138 class BoostedTreesSerializeEnsembleOp : public OpKernel {
139  public:
BoostedTreesSerializeEnsembleOp(OpKernelConstruction * context)140   explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context)
141       : OpKernel(context) {}
142 
Compute(OpKernelContext * context)143   void Compute(OpKernelContext* context) override {
144     core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
145     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
146                                            &tree_ensemble_resource));
147     tf_shared_lock l(*tree_ensemble_resource->get_mutex());
148     Tensor* output_stamp_token_t = nullptr;
149     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
150                                                      &output_stamp_token_t));
151     output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
152     Tensor* output_proto_t = nullptr;
153     OP_REQUIRES_OK(context,
154                    context->allocate_output(1, TensorShape(), &output_proto_t));
155     output_proto_t->scalar<tstring>()() =
156         tree_ensemble_resource->SerializeAsString();
157   }
158 };
159 
160 REGISTER_KERNEL_BUILDER(
161     Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU),
162     BoostedTreesSerializeEnsembleOp);
163 
164 // Op for deserializing a tree ensemble variable from a checkpoint.
165 class BoostedTreesDeserializeEnsembleOp : public OpKernel {
166  public:
BoostedTreesDeserializeEnsembleOp(OpKernelConstruction * context)167   explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context)
168       : OpKernel(context) {}
169 
Compute(OpKernelContext * context)170   void Compute(OpKernelContext* context) override {
171     core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource;
172     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
173                                            &tree_ensemble_resource));
174     mutex_lock l(*tree_ensemble_resource->get_mutex());
175 
176     // Get the stamp token.
177     const Tensor* stamp_token_t;
178     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
179     int64 stamp_token = stamp_token_t->scalar<int64>()();
180 
181     // Get the tree ensemble proto.
182     const Tensor* tree_ensemble_serialized_t;
183     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
184                                            &tree_ensemble_serialized_t));
185     // Deallocate all the previous objects on the resource.
186     tree_ensemble_resource->Reset();
187     OP_REQUIRES(
188         context,
189         tree_ensemble_resource->InitFromSerialized(
190             tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token),
191         errors::InvalidArgument("Unable to parse tree ensemble proto."));
192   }
193 };
194 
195 REGISTER_KERNEL_BUILDER(
196     Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU),
197     BoostedTreesDeserializeEnsembleOp);
198 
199 }  // namespace tensorflow
200