1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <string> 18 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/framework/resource_mgr.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor_shape.h" 23 #include "tensorflow/core/framework/tensor_types.h" 24 #include "tensorflow/core/kernels/boosted_trees/resources.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 27 namespace tensorflow { 28 29 REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource); 30 31 REGISTER_KERNEL_BUILDER( 32 Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU), 33 IsResourceInitialized<BoostedTreesEnsembleResource>); 34 35 // Creates a tree ensemble resource. 36 class BoostedTreesCreateEnsembleOp : public OpKernel { 37 public: BoostedTreesCreateEnsembleOp(OpKernelConstruction * context)38 explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context) 39 : OpKernel(context) {} 40 Compute(OpKernelContext * context)41 void Compute(OpKernelContext* context) override { 42 // Get the stamp token. 43 const Tensor* stamp_token_t; 44 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 45 int64_t stamp_token = stamp_token_t->scalar<int64>()(); 46 47 // Get the tree ensemble proto. 48 const Tensor* tree_ensemble_serialized_t; 49 OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized", 50 &tree_ensemble_serialized_t)); 51 std::unique_ptr<BoostedTreesEnsembleResource> result( 52 new BoostedTreesEnsembleResource()); 53 if (!result->InitFromSerialized( 54 tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token)) { 55 result->Unref(); 56 result.release(); // Needed due to the `->Unref` above, to prevent UAF 57 OP_REQUIRES( 58 context, false, 59 errors::InvalidArgument("Unable to parse tree ensemble proto.")); 60 } 61 62 // Only create one, if one does not exist already. Report status for all 63 // other exceptions. 64 auto status = 65 CreateResource(context, HandleFromInput(context, 0), result.release()); 66 if (status.code() != tensorflow::error::ALREADY_EXISTS) { 67 OP_REQUIRES_OK(context, status); 68 } 69 } 70 }; 71 72 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU), 73 BoostedTreesCreateEnsembleOp); 74 75 // Op for retrieving some model states (needed for training). 76 class BoostedTreesGetEnsembleStatesOp : public OpKernel { 77 public: BoostedTreesGetEnsembleStatesOp(OpKernelConstruction * context)78 explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context) 79 : OpKernel(context) {} 80 Compute(OpKernelContext * context)81 void Compute(OpKernelContext* context) override { 82 // Looks up the resource. 83 core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; 84 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 85 &tree_ensemble_resource)); 86 tf_shared_lock l(*tree_ensemble_resource->get_mutex()); 87 88 // Sets the outputs. 89 const int num_trees = tree_ensemble_resource->num_trees(); 90 const int num_finalized_trees = 91 (num_trees <= 0 || 92 tree_ensemble_resource->IsTreeFinalized(num_trees - 1)) 93 ? num_trees 94 : num_trees - 1; 95 const int num_attempted_layers = 96 tree_ensemble_resource->GetNumLayersAttempted(); 97 98 // growing_metadata 99 Tensor* output_stamp_token_t = nullptr; 100 Tensor* output_num_trees_t = nullptr; 101 Tensor* output_num_finalized_trees_t = nullptr; 102 Tensor* output_num_attempted_layers_t = nullptr; 103 Tensor* output_last_layer_nodes_range_t = nullptr; 104 105 OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), 106 &output_stamp_token_t)); 107 OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(), 108 &output_num_trees_t)); 109 OP_REQUIRES_OK(context, 110 context->allocate_output(2, TensorShape(), 111 &output_num_finalized_trees_t)); 112 OP_REQUIRES_OK(context, 113 context->allocate_output(3, TensorShape(), 114 &output_num_attempted_layers_t)); 115 OP_REQUIRES_OK(context, context->allocate_output( 116 4, {2}, &output_last_layer_nodes_range_t)); 117 118 output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp(); 119 output_num_trees_t->scalar<int32>()() = num_trees; 120 output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees; 121 output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers; 122 123 int32_t range_start; 124 int32_t range_end; 125 tree_ensemble_resource->GetLastLayerNodesRange(&range_start, &range_end); 126 127 output_last_layer_nodes_range_t->vec<int32>()(0) = range_start; 128 // For a completely empty ensemble, this will be 0. To make it a valid range 129 // we add this max cond. 130 output_last_layer_nodes_range_t->vec<int32>()(1) = std::max(1, range_end); 131 } 132 }; 133 134 REGISTER_KERNEL_BUILDER( 135 Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU), 136 BoostedTreesGetEnsembleStatesOp); 137 138 // Op for serializing a model. 139 class BoostedTreesSerializeEnsembleOp : public OpKernel { 140 public: BoostedTreesSerializeEnsembleOp(OpKernelConstruction * context)141 explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context) 142 : OpKernel(context) {} 143 Compute(OpKernelContext * context)144 void Compute(OpKernelContext* context) override { 145 core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; 146 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 147 &tree_ensemble_resource)); 148 tf_shared_lock l(*tree_ensemble_resource->get_mutex()); 149 Tensor* output_stamp_token_t = nullptr; 150 OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), 151 &output_stamp_token_t)); 152 output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp(); 153 Tensor* output_proto_t = nullptr; 154 OP_REQUIRES_OK(context, 155 context->allocate_output(1, TensorShape(), &output_proto_t)); 156 output_proto_t->scalar<tstring>()() = 157 tree_ensemble_resource->SerializeAsString(); 158 } 159 }; 160 161 REGISTER_KERNEL_BUILDER( 162 Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU), 163 BoostedTreesSerializeEnsembleOp); 164 165 // Op for deserializing a tree ensemble variable from a checkpoint. 166 class BoostedTreesDeserializeEnsembleOp : public OpKernel { 167 public: BoostedTreesDeserializeEnsembleOp(OpKernelConstruction * context)168 explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context) 169 : OpKernel(context) {} 170 Compute(OpKernelContext * context)171 void Compute(OpKernelContext* context) override { 172 core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; 173 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 174 &tree_ensemble_resource)); 175 mutex_lock l(*tree_ensemble_resource->get_mutex()); 176 177 // Get the stamp token. 178 const Tensor* stamp_token_t; 179 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 180 int64_t stamp_token = stamp_token_t->scalar<int64>()(); 181 182 // Get the tree ensemble proto. 183 const Tensor* tree_ensemble_serialized_t; 184 OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized", 185 &tree_ensemble_serialized_t)); 186 // Deallocate all the previous objects on the resource. 187 tree_ensemble_resource->Reset(); 188 OP_REQUIRES( 189 context, 190 tree_ensemble_resource->InitFromSerialized( 191 tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token), 192 errors::InvalidArgument("Unable to parse tree ensemble proto.")); 193 } 194 }; 195 196 REGISTER_KERNEL_BUILDER( 197 Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU), 198 BoostedTreesDeserializeEnsembleOp); 199 200 } // namespace tensorflow 201