1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // An abstraction to pick from one of N elements with a specified
17 // weight per element.
18 //
// The weight of a given element can be changed in O(lg N) time.
// An element can be picked in O(lg N) time.
21 //
22 // Uses O(N) bytes of memory.
23 //
24 // Alternative: distribution-sampler.h allows O(1) time picking, but no weight
25 // adjustment after construction.
26
27 #ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
28 #define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
29
30 #include <assert.h>
31
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace tensorflow {
37 namespace random {
38
39 class SimplePhilox;
40
41 class WeightedPicker {
42 public:
43 // REQUIRES N >= 0
44 // Initializes the elements with a weight of one per element
45 explicit WeightedPicker(int N);
46
47 // Releases all resources
48 ~WeightedPicker();
49
50 // Pick a random element with probability proportional to its weight.
51 // If total weight is zero, returns -1.
52 int Pick(SimplePhilox* rnd) const;
53
54 // Deterministically pick element x whose weight covers the
55 // specified weight_index.
56 // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
57 int PickAt(int32 weight_index) const;
58
59 // Get the weight associated with an element
60 // REQUIRES 0 <= index < N
61 int32 get_weight(int index) const;
62
63 // Set the weight associated with an element
64 // REQUIRES weight >= 0.0f
65 // REQUIRES 0 <= index < N
66 void set_weight(int index, int32 weight);
67
68 // Get the total combined weight of all elements
69 int32 total_weight() const;
70
71 // Get the number of elements in the picker
72 int num_elements() const;
73
74 // Set weight of each element to "weight"
75 void SetAllWeights(int32 weight);
76
77 // Resizes the picker to N and
78 // sets the weight of each element i to weight[i].
79 // The sum of the weights should not exceed 2^31 - 2
80 // Complexity O(N).
81 void SetWeightsFromArray(int N, const int32* weights);
82
83 // REQUIRES N >= 0
84 //
85 // Resize the weighted picker so that it has "N" elements.
86 // Any newly added entries have zero weight.
87 //
88 // Note: Resizing to a smaller size than num_elements() will
89 // not reclaim any memory. If you wish to reduce memory usage,
90 // allocate a new WeightedPicker of the appropriate size.
91 //
92 // It is efficient to use repeated calls to Resize(num_elements() + 1)
93 // to grow the picker to size X (takes total time O(X)).
94 void Resize(int N);
95
96 // Grow the picker by one and set the weight of the new entry to "weight".
97 //
98 // Repeated calls to Append() in order to grow the
99 // picker to size X takes a total time of O(X lg(X)).
100 // Consider using SetWeightsFromArray instead.
101 void Append(int32 weight);
102
103 private:
104 // We keep a binary tree with N leaves. The "i"th leaf contains
105 // the weight of the "i"th element. An internal node contains
106 // the sum of the weights of its children.
107 int N_; // Number of elements
108 int num_levels_; // Number of levels in tree (level-0 is root)
109 int32** level_; // Array that holds nodes per level
110
111 // Size of each level
LevelSize(int level)112 static int LevelSize(int level) { return 1 << level; }
113
114 // Rebuild the tree weights using the leaf weights
115 void RebuildTreeWeights();
116
117 TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
118 };
119
get_weight(int index)120 inline int32 WeightedPicker::get_weight(int index) const {
121 DCHECK_GE(index, 0);
122 DCHECK_LT(index, N_);
123 return level_[num_levels_ - 1][index];
124 }
125
total_weight()126 inline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
127
num_elements()128 inline int WeightedPicker::num_elements() const { return N_; }
129
130 } // namespace random
131 } // namespace tensorflow
132
133 #endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
134