1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
25
26 using tensorflow::Status;
27 using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
28 using tensorflow::random::PhiloxRandom;
29 using tensorflow::random::SimplePhilox;
30
31 namespace tensorflow {
32 namespace boosted_trees {
33 namespace utils {
34
DropOutTrees(const uint64 seed,const LearningRateDropoutDrivenConfig & config,const std::unordered_set<int32> & trees_not_to_drop,const std::vector<float> & weights,std::vector<int32> * dropped_trees,std::vector<float> * original_weights)35 Status DropoutUtils::DropOutTrees(
36 const uint64 seed, const LearningRateDropoutDrivenConfig& config,
37 const std::unordered_set<int32>& trees_not_to_drop,
38 const std::vector<float>& weights, std::vector<int32>* dropped_trees,
39 std::vector<float>* original_weights) {
40 // Verify params.
41 if (dropped_trees == nullptr) {
42 return errors::Internal("Dropped trees is nullptr.");
43 }
44 if (original_weights == nullptr) {
45 return errors::InvalidArgument("Original weights is nullptr.");
46 }
47 const float dropout_probability = config.dropout_probability();
48 if (dropout_probability < 0 || dropout_probability > 1) {
49 return errors::InvalidArgument(
50 "Dropout probability must be in [0,1] range");
51 }
52 const float probability_of_skipping_dropout =
53 config.probability_of_skipping_dropout();
54 if (probability_of_skipping_dropout < 0 ||
55 probability_of_skipping_dropout > 1) {
56 return errors::InvalidArgument(
57 "Probability of skipping dropout must be in [0,1] range");
58 }
59 const auto num_trees = weights.size();
60
61 dropped_trees->clear();
62 original_weights->clear();
63
64 // If dropout is no op, return.
65 if (dropout_probability == 0 || probability_of_skipping_dropout == 1.0) {
66 return Status::OK();
67 }
68
69 // Roll the dice for each tree.
70 PhiloxRandom philox(seed);
71 SimplePhilox rng(&philox);
72
73 std::vector<int32> trees_to_keep;
74
75 // What is the probability of skipping dropout altogether.
76 if (probability_of_skipping_dropout != 0) {
77 // First roll the dice - do we do dropout
78 double roll = rng.RandDouble();
79 if (roll < probability_of_skipping_dropout) {
80 // don't do dropout
81 return Status::OK();
82 }
83 }
84
85 for (int32 i = 0; i < num_trees; ++i) {
86 // We can't drop some of the trees: for example, bias tree in batch mode,
87 // or current tree that is built, in the batch mode.
88 if (trees_not_to_drop.find(i) != trees_not_to_drop.end()) {
89 continue;
90 }
91 double roll = rng.RandDouble();
92 if (roll >= dropout_probability) {
93 trees_to_keep.push_back(i);
94 } else {
95 dropped_trees->push_back(i);
96 }
97 }
98
99 // Sort the dropped trees indices.
100 std::sort(dropped_trees->begin(), dropped_trees->end());
101 for (const int32 dropped_tree : *dropped_trees) {
102 original_weights->push_back(weights[dropped_tree]);
103 }
104
105 return Status::OK();
106 }
107
GetTreesWeightsForAddingTrees(const std::vector<int32> & dropped_trees,const std::vector<float> & dropped_trees_original_weights,const int32 new_trees_first_index,const int32 num_trees_to_add,std::vector<float> * current_weights,std::vector<int32> * num_updates)108 void DropoutUtils::GetTreesWeightsForAddingTrees(
109 const std::vector<int32>& dropped_trees,
110 const std::vector<float>& dropped_trees_original_weights,
111 const int32 new_trees_first_index, const int32 num_trees_to_add,
112 std::vector<float>* current_weights, std::vector<int32>* num_updates) {
113 CHECK(num_updates->size() == current_weights->size());
114 // combined weight of trees that were dropped out
115
116 const float dropped_sum =
117 std::accumulate(dropped_trees_original_weights.begin(),
118 dropped_trees_original_weights.end(), 0.0);
119
120 const int num_dropped = dropped_trees.size();
121
122 // Allocate additional weight for the new tree
123 const float total_new_trees_weight = dropped_sum / (num_dropped + 1);
124
125 for (int i = 0; i < num_trees_to_add; ++i) {
126 const int32 new_tree_index = new_trees_first_index + i;
127 if (new_tree_index < current_weights->size()) {
128 // We have the entries in weights and updates for this tree already
129 (*current_weights)[new_tree_index] =
130 total_new_trees_weight / num_trees_to_add;
131 (*num_updates)[new_tree_index]++;
132 } else {
133 // We need to add a new entry. This is non-batch mode.
134 current_weights->push_back(total_new_trees_weight / num_trees_to_add);
135 num_updates->push_back(1);
136 }
137 }
138
139 for (int32 i = 0; i < dropped_trees.size(); ++i) {
140 const int32 dropped = dropped_trees[i];
141 const float original_weight = dropped_trees_original_weights[i];
142 const float new_weight = original_weight * num_dropped / (num_dropped + 1);
143 (*current_weights)[dropped] = new_weight;
144 // Update the number of updates per tree.
145 ++(*num_updates)[dropped];
146 }
147 }
148
149 } // namespace utils
150 } // namespace boosted_trees
151 } // namespace tensorflow
152