1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ 16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ 17 18 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 19 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h" 20 #include "tensorflow/core/framework/resource_mgr.h" 21 #include "tensorflow/core/platform/mutex.h" 22 #include "tensorflow/core/platform/protobuf.h" 23 24 namespace tensorflow { 25 namespace boosted_trees { 26 namespace models { 27 28 // Keep a tree ensemble in memory for efficient evaluation and mutation. 29 class DecisionTreeEnsembleResource : public StampedResource { 30 public: 31 // Constructor. DecisionTreeEnsembleResource()32 explicit DecisionTreeEnsembleResource() 33 : decision_tree_ensemble_( 34 protobuf::Arena::CreateMessage< 35 boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {} 36 DebugString()37 string DebugString() const override { 38 return strings::StrCat("GTFlowDecisionTreeEnsemble[size=", 39 decision_tree_ensemble_->trees_size(), "]"); 40 } 41 42 const boosted_trees::trees::DecisionTreeEnsembleConfig& decision_tree_ensemble()43 decision_tree_ensemble() const { 44 return *decision_tree_ensemble_; 45 } 46 num_trees()47 int32 num_trees() const { return decision_tree_ensemble_->trees_size(); } 48 InitFromSerialized(const string & serialized,const int64 stamp_token)49 bool InitFromSerialized(const string& serialized, const int64 stamp_token) { 50 CHECK_EQ(stamp(), -1) << "Must Reset before Init."; 51 if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) { 52 set_stamp(stamp_token); 53 return true; 54 } 55 return false; 56 } 57 SerializeAsString()58 string SerializeAsString() const { 59 return decision_tree_ensemble_->SerializeAsString(); 60 } 61 62 // Increment num_layers_attempted and num_trees_attempted in growing_metadata 63 // if the tree is finalized. IncrementAttempts()64 void IncrementAttempts() { 65 boosted_trees::trees::GrowingMetadata* const growing_metadata = 66 decision_tree_ensemble_->mutable_growing_metadata(); 67 growing_metadata->set_num_layers_attempted( 68 growing_metadata->num_layers_attempted() + 1); 69 const int num_trees = decision_tree_ensemble_->trees_size(); 70 if (num_trees <= 0 || LastTreeMetadata()->is_finalized()) { 71 growing_metadata->set_num_trees_attempted( 72 growing_metadata->num_trees_attempted() + 1); 73 } 74 } 75 AddNewTree(const float weight)76 boosted_trees::trees::DecisionTreeConfig* AddNewTree(const float weight) { 77 // Adding a tree as well as a weight and a tree_metadata. 78 decision_tree_ensemble_->add_tree_weights(weight); 79 boosted_trees::trees::DecisionTreeMetadata* const metadata = 80 decision_tree_ensemble_->add_tree_metadata(); 81 metadata->set_num_layers_grown(1); 82 return decision_tree_ensemble_->add_trees(); 83 } 84 RemoveLastTree()85 void RemoveLastTree() { 86 QCHECK_GT(decision_tree_ensemble_->trees_size(), 0); 87 decision_tree_ensemble_->mutable_trees()->RemoveLast(); 88 decision_tree_ensemble_->mutable_tree_weights()->RemoveLast(); 89 decision_tree_ensemble_->mutable_tree_metadata()->RemoveLast(); 90 } 91 LastTree()92 boosted_trees::trees::DecisionTreeConfig* LastTree() { 93 const int32 tree_size = decision_tree_ensemble_->trees_size(); 94 QCHECK_GT(tree_size, 0); 95 return decision_tree_ensemble_->mutable_trees(tree_size - 1); 96 } 97 LastTreeMetadata()98 boosted_trees::trees::DecisionTreeMetadata* LastTreeMetadata() { 99 const int32 metadata_size = decision_tree_ensemble_->tree_metadata_size(); 100 QCHECK_GT(metadata_size, 0); 101 return decision_tree_ensemble_->mutable_tree_metadata(metadata_size - 1); 102 } 103 104 // Retrieves tree weights and returns as a vector. GetTreeWeights()105 std::vector<float> GetTreeWeights() const { 106 return {decision_tree_ensemble_->tree_weights().begin(), 107 decision_tree_ensemble_->tree_weights().end()}; 108 } 109 GetTreeWeight(const int32 index)110 float GetTreeWeight(const int32 index) const { 111 return decision_tree_ensemble_->tree_weights(index); 112 } 113 MaybeAddUsedHandler(const int32 handler_id)114 void MaybeAddUsedHandler(const int32 handler_id) { 115 protobuf::RepeatedField<protobuf_int64>* used_ids = 116 decision_tree_ensemble_->mutable_growing_metadata() 117 ->mutable_used_handler_ids(); 118 protobuf::RepeatedField<protobuf_int64>::iterator first = 119 std::lower_bound(used_ids->begin(), used_ids->end(), handler_id); 120 if (first == used_ids->end()) { 121 used_ids->Add(handler_id); 122 return; 123 } 124 if (handler_id == *first) { 125 // It is a duplicate entry. 126 return; 127 } 128 used_ids->Add(handler_id); 129 // Keep the list of used handlers sorted. 130 std::sort(used_ids->begin(), used_ids->end()); 131 } 132 GetUsedHandlers()133 std::vector<int64> GetUsedHandlers() const { 134 std::vector<int64> result; 135 result.reserve( 136 decision_tree_ensemble_->growing_metadata().used_handler_ids().size()); 137 for (int64 h : 138 decision_tree_ensemble_->growing_metadata().used_handler_ids()) { 139 result.push_back(h); 140 } 141 return result; 142 } 143 144 // Sets the weight of i'th tree, and increment num_updates in tree_metadata. SetTreeWeight(const int32 index,const float weight,const int32 increment_num_updates)145 void SetTreeWeight(const int32 index, const float weight, 146 const int32 increment_num_updates) { 147 QCHECK_GE(index, 0); 148 QCHECK_LT(index, num_trees()); 149 decision_tree_ensemble_->set_tree_weights(index, weight); 150 if (increment_num_updates != 0) { 151 const int32 num_updates = decision_tree_ensemble_->tree_metadata(index) 152 .num_tree_weight_updates(); 153 decision_tree_ensemble_->mutable_tree_metadata(index) 154 ->set_num_tree_weight_updates(num_updates + increment_num_updates); 155 } 156 } 157 158 // Resets the resource and frees the protos in arena. 159 // Caller needs to hold the mutex lock while calling this. Reset()160 virtual void Reset() { 161 // Reset stamp. 162 set_stamp(-1); 163 164 // Clear tree ensemle. 165 arena_.Reset(); 166 CHECK_EQ(0, arena_.SpaceAllocated()); 167 decision_tree_ensemble_ = protobuf::Arena::CreateMessage< 168 boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_); 169 } 170 get_mutex()171 mutex* get_mutex() { return &mu_; } 172 173 protected: 174 protobuf::Arena arena_; 175 mutex mu_; 176 boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_; 177 }; 178 179 } // namespace models 180 } // namespace boosted_trees 181 } // namespace tensorflow 182 183 #endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ 184