1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 17 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 18 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" 21 #include "tensorflow/core/platform/mutex.h" 22 #include "tensorflow/core/platform/protobuf.h" 23 24 namespace tensorflow { 25 26 // Forward declaration for proto class TreeEnsemble 27 namespace boosted_trees { 28 class TreeEnsemble; 29 class Node; 30 } // namespace boosted_trees 31 32 // A StampedResource is a resource that has a stamp token associated with it. 33 // Before reading from or applying updates to the resource, the stamp should 34 // be checked to verify that the update is not stale. 35 class StampedResource : public ResourceBase { 36 public: StampedResource()37 StampedResource() : stamp_(-1) {} 38 is_stamp_valid(int64_t stamp)39 bool is_stamp_valid(int64_t stamp) const { return stamp_ == stamp; } 40 stamp()41 int64 stamp() const { return stamp_; } set_stamp(int64_t stamp)42 void set_stamp(int64_t stamp) { stamp_ = stamp; } 43 44 private: 45 int64 stamp_; 46 }; 47 48 // Keep a tree ensemble in memory for efficient evaluation and mutation. 49 class BoostedTreesEnsembleResource : public StampedResource { 50 public: 51 BoostedTreesEnsembleResource(); 52 53 string DebugString() const override; 54 55 bool InitFromSerialized(const string& serialized, const int64_t stamp_token); 56 57 string SerializeAsString() const; 58 59 int32 num_trees() const; 60 61 // Find the next node to which the example (specified by index_in_batch) 62 // traverses down from the current node indicated by tree_id and node_id. 63 // Args: 64 // tree_id: the index of the tree in the ensemble. 65 // node_id: the index of the node within the tree. 66 // index_in_batch: the index of the example within the batch (relevant to 67 // the index of the row to read in each bucketized_features). 68 // bucketized_features: vector of feature Vectors. 69 int32 next_node( 70 const int32_t tree_id, const int32_t node_id, 71 const int32_t index_in_batch, 72 const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const; 73 74 std::vector<float> node_value(const int32_t tree_id, 75 const int32_t node_id) const; 76 77 void set_node_value(const int32_t tree_id, const int32_t node_id, 78 const float logits); 79 80 int32 GetNumLayersGrown(const int32_t tree_id) const; 81 82 void SetNumLayersGrown(const int32_t tree_id, int32_t new_num_layers) const; 83 84 void UpdateLastLayerNodesRange(const int32_t node_range_start, 85 int32_t node_range_end) const; 86 87 void GetLastLayerNodesRange(int32* node_range_start, 88 int32* node_range_end) const; 89 90 int64 GetNumNodes(const int32_t tree_id); 91 92 void UpdateGrowingMetadata() const; 93 94 int32 GetNumLayersAttempted(); 95 96 bool is_leaf(const int32_t tree_id, const int32_t node_id) const; 97 98 int32 feature_id(const int32_t tree_id, const int32_t node_id) const; 99 100 int32 bucket_threshold(const int32_t tree_id, const int32_t node_id) const; 101 102 int32 left_id(const int32_t tree_id, const int32_t node_id) const; 103 104 int32 right_id(const int32_t tree_id, const int32_t node_id) const; 105 106 // Add a tree to the ensemble and returns a new tree_id. 107 int32 AddNewTree(const float weight, const int32_t logits_dimension); 108 109 // Adds new tree with one node to the ensemble and sets node's value to logits 110 int32 AddNewTreeWithLogits(const float weight, 111 const std::vector<float>& logits, 112 const int32_t logits_dimension); 113 114 // Grows the tree by adding a bucketized split and leaves. 115 void AddBucketizedSplitNode( 116 const int32_t tree_id, 117 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry, 118 const int32_t logits_dimension, int32* left_node_id, 119 int32* right_node_id); 120 121 // Grows the tree by adding a categorical split and leaves. 122 void AddCategoricalSplitNode( 123 const int32_t tree_id, 124 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry, 125 const int32_t logits_dimension, int32* left_node_id, 126 int32* right_node_id); 127 128 // Retrieves tree weights and returns as a vector. 129 // It involves a copy, so should be called only sparingly (like once per 130 // iteration, not per example). 131 std::vector<float> GetTreeWeights() const; 132 133 float GetTreeWeight(const int32_t tree_id) const; 134 135 float IsTreeFinalized(const int32_t tree_id) const; 136 137 float IsTreePostPruned(const int32_t tree_id) const; 138 139 void SetIsFinalized(const int32_t tree_id, const bool is_finalized); 140 141 // Sets the weight of i'th tree. 142 void SetTreeWeight(const int32_t tree_id, const float weight); 143 144 // Resets the resource and frees the protos in arena. 145 // Caller needs to hold the mutex lock while calling this. 146 virtual void Reset(); 147 148 void PostPruneTree(const int32_t current_tree, 149 const int32_t logits_dimension); 150 151 // For a given node, returns the id in a pruned tree, as well as correction 152 // to the cached prediction that should be applied. If tree was not 153 // post-pruned, current_node_id will be equal to initial_node_id and logit 154 // update will be equal to zero. 155 void GetPostPruneCorrection(const int32_t tree_id, 156 const int32_t initial_node_id, 157 int32* current_node_id, 158 std::vector<float>* logit_updates) const; get_mutex()159 mutex* get_mutex() { return &mu_; } 160 161 private: 162 // Helper method to check whether a node is a terminal node in that it 163 // only has leaf nodes as children. 164 bool IsTerminalSplitNode(const int32_t tree_id, const int32_t node_id) const; 165 166 // For each pruned node, finds the leaf where it finally ended up and 167 // calculates the total update from that pruned node prediction. 168 void CalculateParentAndLogitUpdate( 169 const int32_t start_node_id, 170 const std::vector<std::pair<int32, std::vector<float>>>& nodes_change, 171 int32* parent_id, std::vector<float>* change) const; 172 173 // Helper method to collect the information to be used to prune some nodes in 174 // the tree. 175 void RecursivelyDoPostPrunePreparation( 176 const int32_t tree_id, const int32_t node_id, 177 std::vector<int32>* nodes_to_delete, 178 std::vector<std::pair<int32, std::vector<float>>>* nodes_meta); 179 180 protected: 181 protobuf::Arena arena_; 182 mutex mu_; 183 boosted_trees::TreeEnsemble* tree_ensemble_; 184 185 boosted_trees::Node* AddLeafNodes( 186 int32_t tree_id, 187 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry, 188 const int32_t logits_dimension, int32* left_node_id, 189 int32* right_node_id); 190 }; 191 192 } // namespace tensorflow 193 194 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 195