1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <string> 18 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/framework/resource_mgr.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor_shape.h" 23 #include "tensorflow/core/framework/tensor_types.h" 24 #include "tensorflow/core/kernels/boosted_trees/resources.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 27 namespace tensorflow { 28 29 REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource); 30 31 REGISTER_KERNEL_BUILDER( 32 Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU), 33 IsResourceInitialized<BoostedTreesEnsembleResource>); 34 35 // Creates a tree ensemble resource. 36 class BoostedTreesCreateEnsembleOp : public OpKernel { 37 public: BoostedTreesCreateEnsembleOp(OpKernelConstruction * context)38 explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context) 39 : OpKernel(context) {} 40 Compute(OpKernelContext * context)41 void Compute(OpKernelContext* context) override { 42 // Get the stamp token. 43 const Tensor* stamp_token_t; 44 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 45 int64 stamp_token = stamp_token_t->scalar<int64>()(); 46 47 // Get the tree ensemble proto. 48 const Tensor* tree_ensemble_serialized_t; 49 OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized", 50 &tree_ensemble_serialized_t)); 51 std::unique_ptr<BoostedTreesEnsembleResource> result( 52 new BoostedTreesEnsembleResource()); 53 if (!result->InitFromSerialized( 54 tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token)) { 55 result->Unref(); 56 OP_REQUIRES( 57 context, false, 58 errors::InvalidArgument("Unable to parse tree ensemble proto.")); 59 } 60 61 // Only create one, if one does not exist already. Report status for all 62 // other exceptions. 63 auto status = 64 CreateResource(context, HandleFromInput(context, 0), result.release()); 65 if (status.code() != tensorflow::error::ALREADY_EXISTS) { 66 OP_REQUIRES_OK(context, status); 67 } 68 } 69 }; 70 71 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU), 72 BoostedTreesCreateEnsembleOp); 73 74 // Op for retrieving some model states (needed for training). 75 class BoostedTreesGetEnsembleStatesOp : public OpKernel { 76 public: BoostedTreesGetEnsembleStatesOp(OpKernelConstruction * context)77 explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context) 78 : OpKernel(context) {} 79 Compute(OpKernelContext * context)80 void Compute(OpKernelContext* context) override { 81 // Looks up the resource. 82 core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; 83 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 84 &tree_ensemble_resource)); 85 tf_shared_lock l(*tree_ensemble_resource->get_mutex()); 86 87 // Sets the outputs. 88 const int num_trees = tree_ensemble_resource->num_trees(); 89 const int num_finalized_trees = 90 (num_trees <= 0 || 91 tree_ensemble_resource->IsTreeFinalized(num_trees - 1)) 92 ? num_trees 93 : num_trees - 1; 94 const int num_attempted_layers = 95 tree_ensemble_resource->GetNumLayersAttempted(); 96 97 // growing_metadata 98 Tensor* output_stamp_token_t = nullptr; 99 Tensor* output_num_trees_t = nullptr; 100 Tensor* output_num_finalized_trees_t = nullptr; 101 Tensor* output_num_attempted_layers_t = nullptr; 102 Tensor* output_last_layer_nodes_range_t = nullptr; 103 104 OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), 105 &output_stamp_token_t)); 106 OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(), 107 &output_num_trees_t)); 108 OP_REQUIRES_OK(context, 109 context->allocate_output(2, TensorShape(), 110 &output_num_finalized_trees_t)); 111 OP_REQUIRES_OK(context, 112 context->allocate_output(3, TensorShape(), 113 &output_num_attempted_layers_t)); 114 OP_REQUIRES_OK(context, context->allocate_output( 115 4, {2}, &output_last_layer_nodes_range_t)); 116 117 output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp(); 118 output_num_trees_t->scalar<int32>()() = num_trees; 119 output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees; 120 output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers; 121 122 int32 range_start; 123 int32 range_end; 124 tree_ensemble_resource->GetLastLayerNodesRange(&range_start, &range_end); 125 126 output_last_layer_nodes_range_t->vec<int32>()(0) = range_start; 127 // For a completely empty ensemble, this will be 0. To make it a valid range 128 // we add this max cond. 129 output_last_layer_nodes_range_t->vec<int32>()(1) = std::max(1, range_end); 130 } 131 }; 132 133 REGISTER_KERNEL_BUILDER( 134 Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU), 135 BoostedTreesGetEnsembleStatesOp); 136 137 // Op for serializing a model. 138 class BoostedTreesSerializeEnsembleOp : public OpKernel { 139 public: BoostedTreesSerializeEnsembleOp(OpKernelConstruction * context)140 explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context) 141 : OpKernel(context) {} 142 Compute(OpKernelContext * context)143 void Compute(OpKernelContext* context) override { 144 core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; 145 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 146 &tree_ensemble_resource)); 147 tf_shared_lock l(*tree_ensemble_resource->get_mutex()); 148 Tensor* output_stamp_token_t = nullptr; 149 OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), 150 &output_stamp_token_t)); 151 output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp(); 152 Tensor* output_proto_t = nullptr; 153 OP_REQUIRES_OK(context, 154 context->allocate_output(1, TensorShape(), &output_proto_t)); 155 output_proto_t->scalar<tstring>()() = 156 tree_ensemble_resource->SerializeAsString(); 157 } 158 }; 159 160 REGISTER_KERNEL_BUILDER( 161 Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU), 162 BoostedTreesSerializeEnsembleOp); 163 164 // Op for deserializing a tree ensemble variable from a checkpoint. 165 class BoostedTreesDeserializeEnsembleOp : public OpKernel { 166 public: BoostedTreesDeserializeEnsembleOp(OpKernelConstruction * context)167 explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context) 168 : OpKernel(context) {} 169 Compute(OpKernelContext * context)170 void Compute(OpKernelContext* context) override { 171 core::RefCountPtr<BoostedTreesEnsembleResource> tree_ensemble_resource; 172 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 173 &tree_ensemble_resource)); 174 mutex_lock l(*tree_ensemble_resource->get_mutex()); 175 176 // Get the stamp token. 177 const Tensor* stamp_token_t; 178 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 179 int64 stamp_token = stamp_token_t->scalar<int64>()(); 180 181 // Get the tree ensemble proto. 182 const Tensor* tree_ensemble_serialized_t; 183 OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized", 184 &tree_ensemble_serialized_t)); 185 // Deallocate all the previous objects on the resource. 186 tree_ensemble_resource->Reset(); 187 OP_REQUIRES( 188 context, 189 tree_ensemble_resource->InitFromSerialized( 190 tree_ensemble_serialized_t->scalar<tstring>()(), stamp_token), 191 errors::InvalidArgument("Unable to parse tree ensemble proto.")); 192 } 193 }; 194 195 REGISTER_KERNEL_BUILDER( 196 Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU), 197 BoostedTreesDeserializeEnsembleOp); 198 199 } // namespace tensorflow 200