1 // Copyright 2019 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 #ifndef ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_
16 #define ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_
17
18 #include <cmath>
19 #include <limits>
20 #include <type_traits>
21
22 #include "absl/meta/type_traits.h"
23
24 namespace absl {
25 ABSL_NAMESPACE_BEGIN
// Forward declarations of the distributions that UniformDistribution (below)
// selects between; their definitions live in other headers.
template <class IntType>
class uniform_int_distribution;

template <class RealType>
class uniform_real_distribution;
31
// Interval tag types, which specify whether an interval is open or closed at
// each boundary, are defined below after their shared comparison base.
34
namespace random_internal {

// Empty base shared by the interval tag types. A tag `Tag` derives from
// `TagTypeCompare<Tag>` to pick up the equality operators declared below;
// they are found through argument-dependent lookup on this base class.
template <typename T>
struct TagTypeCompare {};

// Tag types hold no data, so two values of the same tag type always compare
// equal.
template <typename T>
constexpr bool operator==(TagTypeCompare<T> /*lhs*/,
                          TagTypeCompare<T> /*rhs*/) {
  return true;
}

// Expressed through operator== so the two operators cannot disagree; for
// these stateless tags the result is always false.
template <typename T>
constexpr bool operator!=(TagTypeCompare<T> lhs, TagTypeCompare<T> rhs) {
  return !(lhs == rhs);
}

}  // namespace random_internal
50
51 struct IntervalClosedClosedTag
52 : public random_internal::TagTypeCompare<IntervalClosedClosedTag> {};
53 struct IntervalClosedOpenTag
54 : public random_internal::TagTypeCompare<IntervalClosedOpenTag> {};
55 struct IntervalOpenClosedTag
56 : public random_internal::TagTypeCompare<IntervalOpenClosedTag> {};
57 struct IntervalOpenOpenTag
58 : public random_internal::TagTypeCompare<IntervalOpenOpenTag> {};
59
60 namespace random_internal {
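// The functions
//    uniform_lower_bound(tag, a, b)
// and
//    uniform_upper_bound(tag, a, b)
// are used as implementation details for absl::Uniform(). They translate an
// interval described by a tag into the pair of bounds that is handed to the
// underlying distribution (see UniformDistributionWrapper below).
//
// Conceptually,
//    [a, b] == [uniform_lower_bound(IntervalClosedClosed, a, b),
//               uniform_upper_bound(IntervalClosedClosed, a, b)]
//    (a, b) == [uniform_lower_bound(IntervalOpenOpen, a, b),
//               uniform_upper_bound(IntervalOpenOpen, a, b)]
//    [a, b) == [uniform_lower_bound(IntervalClosedOpen, a, b),
//               uniform_upper_bound(IntervalClosedOpen, a, b)]
//    (a, b] == [uniform_lower_bound(IntervalOpenClosed, a, b),
//               uniform_upper_bound(IntervalOpenClosed, a, b)]
//
// The right-hand brackets are exact for integral types. For floating-point
// types the computed pair is used as a half-open range [lower, upper): an
// open upper end yields `b` itself, while a closed upper end yields
// std::nextafter(b, max), the next representable value above `b`.
//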
77 template <typename IntType, typename Tag>
78 typename absl::enable_if_t<
79 absl::conjunction<
80 std::is_integral<IntType>,
81 absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>,
82 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
83 IntType>
uniform_lower_bound(Tag,IntType a,IntType)84 uniform_lower_bound(Tag, IntType a, IntType) {
85 return a + 1;
86 }
87
88 template <typename FloatType, typename Tag>
89 typename absl::enable_if_t<
90 absl::conjunction<
91 std::is_floating_point<FloatType>,
92 absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>,
93 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
94 FloatType>
uniform_lower_bound(Tag,FloatType a,FloatType b)95 uniform_lower_bound(Tag, FloatType a, FloatType b) {
96 return std::nextafter(a, b);
97 }
98
99 template <typename NumType, typename Tag>
100 typename absl::enable_if_t<
101 absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
102 std::is_same<Tag, IntervalClosedOpenTag>>::value,
103 NumType>
uniform_lower_bound(Tag,NumType a,NumType)104 uniform_lower_bound(Tag, NumType a, NumType) {
105 return a;
106 }
107
108 template <typename IntType, typename Tag>
109 typename absl::enable_if_t<
110 absl::conjunction<
111 std::is_integral<IntType>,
112 absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>,
113 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
114 IntType>
uniform_upper_bound(Tag,IntType,IntType b)115 uniform_upper_bound(Tag, IntType, IntType b) {
116 return b - 1;
117 }
118
119 template <typename FloatType, typename Tag>
120 typename absl::enable_if_t<
121 absl::conjunction<
122 std::is_floating_point<FloatType>,
123 absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>,
124 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
125 FloatType>
uniform_upper_bound(Tag,FloatType,FloatType b)126 uniform_upper_bound(Tag, FloatType, FloatType b) {
127 return b;
128 }
129
130 template <typename IntType, typename Tag>
131 typename absl::enable_if_t<
132 absl::conjunction<
133 std::is_integral<IntType>,
134 absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
135 std::is_same<Tag, IntervalOpenClosedTag>>>::value,
136 IntType>
uniform_upper_bound(Tag,IntType,IntType b)137 uniform_upper_bound(Tag, IntType, IntType b) {
138 return b;
139 }
140
141 template <typename FloatType, typename Tag>
142 typename absl::enable_if_t<
143 absl::conjunction<
144 std::is_floating_point<FloatType>,
145 absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
146 std::is_same<Tag, IntervalOpenClosedTag>>>::value,
147 FloatType>
uniform_upper_bound(Tag,FloatType,FloatType b)148 uniform_upper_bound(Tag, FloatType, FloatType b) {
149 return std::nextafter(b, (std::numeric_limits<FloatType>::max)());
150 }
151
152 template <typename NumType>
153 using UniformDistribution =
154 typename std::conditional<std::is_integral<NumType>::value,
155 absl::uniform_int_distribution<NumType>,
156 absl::uniform_real_distribution<NumType>>::type;
157
158 template <typename NumType>
159 struct UniformDistributionWrapper : public UniformDistribution<NumType> {
160 template <typename TagType>
UniformDistributionWrapperUniformDistributionWrapper161 explicit UniformDistributionWrapper(TagType, NumType lo, NumType hi)
162 : UniformDistribution<NumType>(
163 uniform_lower_bound<NumType>(TagType{}, lo, hi),
164 uniform_upper_bound<NumType>(TagType{}, lo, hi)) {}
165
UniformDistributionWrapperUniformDistributionWrapper166 explicit UniformDistributionWrapper(NumType lo, NumType hi)
167 : UniformDistribution<NumType>(
168 uniform_lower_bound<NumType>(IntervalClosedOpenTag(), lo, hi),
169 uniform_upper_bound<NumType>(IntervalClosedOpenTag(), lo, hi)) {}
170
UniformDistributionWrapperUniformDistributionWrapper171 explicit UniformDistributionWrapper()
172 : UniformDistribution<NumType>(std::numeric_limits<NumType>::lowest(),
173 (std::numeric_limits<NumType>::max)()) {}
174 };
175
176 } // namespace random_internal
177 ABSL_NAMESPACE_END
178 } // namespace absl
179
180 #endif // ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_
181