/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
17 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
18 
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 
24 namespace tensorflow {
25 
// Forward declarations of the proto classes TreeEnsemble and Node. This header
// refers to both only through pointers, so incomplete types are sufficient.
namespace boosted_trees {
class TreeEnsemble;
class Node;
}  // namespace boosted_trees
31 
32 // A StampedResource is a resource that has a stamp token associated with it.
33 // Before reading from or applying updates to the resource, the stamp should
34 // be checked to verify that the update is not stale.
35 class StampedResource : public ResourceBase {
36  public:
StampedResource()37   StampedResource() : stamp_(-1) {}
38 
is_stamp_valid(int64_t stamp)39   bool is_stamp_valid(int64_t stamp) const { return stamp_ == stamp; }
40 
stamp()41   int64 stamp() const { return stamp_; }
set_stamp(int64_t stamp)42   void set_stamp(int64_t stamp) { stamp_ = stamp; }
43 
44  private:
45   int64 stamp_;
46 };
47 
48 // Keep a tree ensemble in memory for efficient evaluation and mutation.
49 class BoostedTreesEnsembleResource : public StampedResource {
50  public:
51   BoostedTreesEnsembleResource();
52 
53   string DebugString() const override;
54 
55   bool InitFromSerialized(const string& serialized, const int64_t stamp_token);
56 
57   string SerializeAsString() const;
58 
59   int32 num_trees() const;
60 
61   // Find the next node to which the example (specified by index_in_batch)
62   // traverses down from the current node indicated by tree_id and node_id.
63   // Args:
64   //   tree_id: the index of the tree in the ensemble.
65   //   node_id: the index of the node within the tree.
66   //   index_in_batch: the index of the example within the batch (relevant to
67   //       the index of the row to read in each bucketized_features).
68   //   bucketized_features: vector of feature Vectors.
69   int32 next_node(
70       const int32_t tree_id, const int32_t node_id,
71       const int32_t index_in_batch,
72       const std::vector<TTypes<int32>::ConstMatrix>& bucketized_features) const;
73 
74   std::vector<float> node_value(const int32_t tree_id,
75                                 const int32_t node_id) const;
76 
77   void set_node_value(const int32_t tree_id, const int32_t node_id,
78                       const float logits);
79 
80   int32 GetNumLayersGrown(const int32_t tree_id) const;
81 
82   void SetNumLayersGrown(const int32_t tree_id, int32_t new_num_layers) const;
83 
84   void UpdateLastLayerNodesRange(const int32_t node_range_start,
85                                  int32_t node_range_end) const;
86 
87   void GetLastLayerNodesRange(int32* node_range_start,
88                               int32* node_range_end) const;
89 
90   int64 GetNumNodes(const int32_t tree_id);
91 
92   void UpdateGrowingMetadata() const;
93 
94   int32 GetNumLayersAttempted();
95 
96   bool is_leaf(const int32_t tree_id, const int32_t node_id) const;
97 
98   int32 feature_id(const int32_t tree_id, const int32_t node_id) const;
99 
100   int32 bucket_threshold(const int32_t tree_id, const int32_t node_id) const;
101 
102   int32 left_id(const int32_t tree_id, const int32_t node_id) const;
103 
104   int32 right_id(const int32_t tree_id, const int32_t node_id) const;
105 
106   // Add a tree to the ensemble and returns a new tree_id.
107   int32 AddNewTree(const float weight, const int32_t logits_dimension);
108 
109   // Adds new tree with one node to the ensemble and sets node's value to logits
110   int32 AddNewTreeWithLogits(const float weight,
111                              const std::vector<float>& logits,
112                              const int32_t logits_dimension);
113 
114   // Grows the tree by adding a bucketized split and leaves.
115   void AddBucketizedSplitNode(
116       const int32_t tree_id,
117       const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
118       const int32_t logits_dimension, int32* left_node_id,
119       int32* right_node_id);
120 
121   // Grows the tree by adding a categorical split and leaves.
122   void AddCategoricalSplitNode(
123       const int32_t tree_id,
124       const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
125       const int32_t logits_dimension, int32* left_node_id,
126       int32* right_node_id);
127 
128   // Retrieves tree weights and returns as a vector.
129   // It involves a copy, so should be called only sparingly (like once per
130   // iteration, not per example).
131   std::vector<float> GetTreeWeights() const;
132 
133   float GetTreeWeight(const int32_t tree_id) const;
134 
135   float IsTreeFinalized(const int32_t tree_id) const;
136 
137   float IsTreePostPruned(const int32_t tree_id) const;
138 
139   void SetIsFinalized(const int32_t tree_id, const bool is_finalized);
140 
141   // Sets the weight of i'th tree.
142   void SetTreeWeight(const int32_t tree_id, const float weight);
143 
144   // Resets the resource and frees the protos in arena.
145   // Caller needs to hold the mutex lock while calling this.
146   virtual void Reset();
147 
148   void PostPruneTree(const int32_t current_tree,
149                      const int32_t logits_dimension);
150 
151   // For a given node, returns the id in a pruned tree, as well as correction
152   // to the cached prediction that should be applied. If tree was not
153   // post-pruned, current_node_id will be equal to initial_node_id and logit
154   // update will be equal to zero.
155   void GetPostPruneCorrection(const int32_t tree_id,
156                               const int32_t initial_node_id,
157                               int32* current_node_id,
158                               std::vector<float>* logit_updates) const;
get_mutex()159   mutex* get_mutex() { return &mu_; }
160 
161  private:
162   // Helper method to check whether a node is a terminal node in that it
163   // only has leaf nodes as children.
164   bool IsTerminalSplitNode(const int32_t tree_id, const int32_t node_id) const;
165 
166   // For each pruned node, finds the leaf where it finally ended up and
167   // calculates the total update from that pruned node prediction.
168   void CalculateParentAndLogitUpdate(
169       const int32_t start_node_id,
170       const std::vector<std::pair<int32, std::vector<float>>>& nodes_change,
171       int32* parent_id, std::vector<float>* change) const;
172 
173   // Helper method to collect the information to be used to prune some nodes in
174   // the tree.
175   void RecursivelyDoPostPrunePreparation(
176       const int32_t tree_id, const int32_t node_id,
177       std::vector<int32>* nodes_to_delete,
178       std::vector<std::pair<int32, std::vector<float>>>* nodes_meta);
179 
180  protected:
181   protobuf::Arena arena_;
182   mutex mu_;
183   boosted_trees::TreeEnsemble* tree_ensemble_;
184 
185   boosted_trees::Node* AddLeafNodes(
186       int32_t tree_id,
187       const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
188       const int32_t logits_dimension, int32* left_node_id,
189       int32* right_node_id);
190 };
191 
192 }  // namespace tensorflow
193 
194 #endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
195