1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 17 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 18 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" 21 #include "tensorflow/core/platform/mutex.h" 22 #include "tensorflow/core/platform/protobuf.h" 23 24 namespace tensorflow { 25 26 // Forward declaration for proto class TreeEnsemble 27 namespace boosted_trees { 28 class TreeEnsemble; 29 class Node; 30 } // namespace boosted_trees 31 32 // A StampedResource is a resource that has a stamp token associated with it. 33 // Before reading from or applying updates to the resource, the stamp should 34 // be checked to verify that the update is not stale. 35 class StampedResource : public ResourceBase { 36 public: StampedResource()37 StampedResource() : stamp_(-1) {} 38 is_stamp_valid(int64 stamp)39 bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; } 40 stamp()41 int64 stamp() const { return stamp_; } set_stamp(int64 stamp)42 void set_stamp(int64 stamp) { stamp_ = stamp; } 43 44 private: 45 int64 stamp_; 46 }; 47 48 // Keep a tree ensemble in memory for efficient evaluation and mutation. 49 class BoostedTreesEnsembleResource : public StampedResource { 50 public: 51 BoostedTreesEnsembleResource(); 52 53 string DebugString() const override; 54 55 bool InitFromSerialized(const string& serialized, const int64 stamp_token); 56 57 string SerializeAsString() const; 58 59 int32 num_trees() const; 60 61 // Find the next node to which the example (specified by index_in_batch) 62 // traverses down from the current node indicated by tree_id and node_id. 63 // Args: 64 // tree_id: the index of the tree in the ensemble. 65 // node_id: the index of the node within the tree. 66 // index_in_batch: the index of the example within the batch (relevant to 67 // the index of the row to read in each bucketized_features). 68 // bucketized_features: vector of feature Vectors. 69 int32 next_node( 70 const int32 tree_id, const int32 node_id, const int32 index_in_batch, 71 const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const; 72 73 std::vector<float> node_value(const int32 tree_id, const int32 node_id) const; 74 75 void set_node_value(const int32 tree_id, const int32 node_id, 76 const float logits); 77 78 int32 GetNumLayersGrown(const int32 tree_id) const; 79 80 void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const; 81 82 void UpdateLastLayerNodesRange(const int32 node_range_start, 83 int32 node_range_end) const; 84 85 void GetLastLayerNodesRange(int32* node_range_start, 86 int32* node_range_end) const; 87 88 int64 GetNumNodes(const int32 tree_id); 89 90 void UpdateGrowingMetadata() const; 91 92 int32 GetNumLayersAttempted(); 93 94 bool is_leaf(const int32 tree_id, const int32 node_id) const; 95 96 int32 feature_id(const int32 tree_id, const int32 node_id) const; 97 98 int32 bucket_threshold(const int32 tree_id, const int32 node_id) const; 99 100 int32 left_id(const int32 tree_id, const int32 node_id) const; 101 102 int32 right_id(const int32 tree_id, const int32 node_id) const; 103 104 // Add a tree to the ensemble and returns a new tree_id. 105 int32 AddNewTree(const float weight, const int32 logits_dimension); 106 107 // Adds new tree with one node to the ensemble and sets node's value to logits 108 int32 AddNewTreeWithLogits(const float weight, 109 const std::vector<float>& logits, 110 const int32 logits_dimension); 111 112 // Grows the tree by adding a bucketized split and leaves. 113 void AddBucketizedSplitNode( 114 const int32 tree_id, 115 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry, 116 const int32 logits_dimension, int32* left_node_id, int32* right_node_id); 117 118 // Grows the tree by adding a categorical split and leaves. 119 void AddCategoricalSplitNode( 120 const int32 tree_id, 121 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry, 122 const int32 logits_dimension, int32* left_node_id, int32* right_node_id); 123 124 // Retrieves tree weights and returns as a vector. 125 // It involves a copy, so should be called only sparingly (like once per 126 // iteration, not per example). 127 std::vector<float> GetTreeWeights() const; 128 129 float GetTreeWeight(const int32 tree_id) const; 130 131 float IsTreeFinalized(const int32 tree_id) const; 132 133 float IsTreePostPruned(const int32 tree_id) const; 134 135 void SetIsFinalized(const int32 tree_id, const bool is_finalized); 136 137 // Sets the weight of i'th tree. 138 void SetTreeWeight(const int32 tree_id, const float weight); 139 140 // Resets the resource and frees the protos in arena. 141 // Caller needs to hold the mutex lock while calling this. 142 virtual void Reset(); 143 144 void PostPruneTree(const int32 current_tree, const int32 logits_dimension); 145 146 // For a given node, returns the id in a pruned tree, as well as correction 147 // to the cached prediction that should be applied. If tree was not 148 // post-pruned, current_node_id will be equal to initial_node_id and logit 149 // update will be equal to zero. 150 void GetPostPruneCorrection(const int32 tree_id, const int32 initial_node_id, 151 int32* current_node_id, 152 std::vector<float>* logit_updates) const; get_mutex()153 mutex* get_mutex() { return &mu_; } 154 155 private: 156 // Helper method to check whether a node is a terminal node in that it 157 // only has leaf nodes as children. 158 bool IsTerminalSplitNode(const int32 tree_id, const int32 node_id) const; 159 160 // For each pruned node, finds the leaf where it finally ended up and 161 // calculates the total update from that pruned node prediction. 162 void CalculateParentAndLogitUpdate( 163 const int32 start_node_id, 164 const std::vector<std::pair<int32, std::vector<float>>>& nodes_change, 165 int32* parent_id, std::vector<float>* change) const; 166 167 // Helper method to collect the information to be used to prune some nodes in 168 // the tree. 169 void RecursivelyDoPostPrunePreparation( 170 const int32 tree_id, const int32 node_id, 171 std::vector<int32>* nodes_to_delete, 172 std::vector<std::pair<int32, std::vector<float>>>* nodes_meta); 173 174 protected: 175 protobuf::Arena arena_; 176 mutex mu_; 177 boosted_trees::TreeEnsemble* tree_ensemble_; 178 179 boosted_trees::Node* AddLeafNodes( 180 int32 tree_id, 181 const std::pair<int32, boosted_trees::SplitCandidate>& split_entry, 182 const int32 logits_dimension, int32* left_node_id, int32* right_node_id); 183 }; 184 185 } // namespace tensorflow 186 187 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ 188