• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright 2018 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
16 #ifndef ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
17 #define ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
18 
#include <atomic>
#include <deque>
#include <string>
#include <tuple>
#include <typeinfo>
#include <utility>

#include "absl/random/random.h"
#include "absl/strings/str_cat.h"
26 
27 namespace absl {
28 ABSL_NAMESPACE_BEGIN
29 namespace random_internal {
30 
31 // MockingBitGenExpectationFormatter is invoked to format unsatisfied mocks
32 // and remaining results into a description string.
33 template <typename DistrT, typename FormatT>
34 struct MockingBitGenExpectationFormatter {
operatorMockingBitGenExpectationFormatter35   std::string operator()(absl::string_view args) {
36     return absl::StrCat(FormatT::FunctionName(), "(", args, ")");
37   }
38 };
39 
40 // MockingBitGenCallFormatter is invoked to format each distribution call
41 // into a description string for the mock log.
42 template <typename DistrT, typename FormatT>
43 struct MockingBitGenCallFormatter {
operatorMockingBitGenCallFormatter44   std::string operator()(const DistrT& dist,
45                          const typename DistrT::result_type& result) {
46     return absl::StrCat(
47         FormatT::FunctionName(), "(", FormatT::FormatArgs(dist), ") => {",
48         FormatT::FormatResults(absl::MakeSpan(&result, 1)), "}");
49   }
50 };
51 
52 class MockingBitGenBase {
53   template <typename>
54   friend struct DistributionCaller;
55   using generator_type = absl::BitGen;
56 
57  public:
58   // URBG interface
59   using result_type = generator_type::result_type;
result_type(min)60   static constexpr result_type(min)() { return (generator_type::min)(); }
result_type(max)61   static constexpr result_type(max)() { return (generator_type::max)(); }
operator()62   result_type operator()() { return gen_(); }
63 
MockingBitGenBase()64   MockingBitGenBase() : gen_(), observed_call_log_() {}
65   virtual ~MockingBitGenBase() = default;
66 
67  protected:
observed_call_log()68   const std::deque<std::string>& observed_call_log() {
69     return observed_call_log_;
70   }
71 
72   // CallImpl is the type-erased virtual dispatch.
73   // The type of dist is always distribution<T>,
74   // The type of result is always distribution<T>::result_type.
75   virtual bool CallImpl(const std::type_info& distr_type, void* dist_args,
76                         void* result) = 0;
77 
78   template <typename DistrT, typename ArgTupleT>
GetTypeId()79   static const std::type_info& GetTypeId() {
80     return typeid(std::pair<absl::decay_t<DistrT>, absl::decay_t<ArgTupleT>>);
81   }
82 
83   // Call the generating distribution function.
84   // Invoked by DistributionCaller<>::Call<DistT, FormatT>.
85   // DistT is the distribution type.
86   // FormatT is the distribution formatter traits type.
87   template <typename DistrT, typename FormatT, typename... Args>
Call(Args &&...args)88   typename DistrT::result_type Call(Args&&... args) {
89     using distr_result_type = typename DistrT::result_type;
90     using ArgTupleT = std::tuple<absl::decay_t<Args>...>;
91 
92     ArgTupleT arg_tuple(std::forward<Args>(args)...);
93     auto dist = absl::make_from_tuple<DistrT>(arg_tuple);
94 
95     distr_result_type result{};
96     bool found_match =
97         CallImpl(GetTypeId<DistrT, ArgTupleT>(), &arg_tuple, &result);
98 
99     if (!found_match) {
100       result = dist(gen_);
101     }
102 
103     // TODO(asoffer): Forwarding the args through means we no longer need to
104     // extract them from the from the distribution in formatter traits. We can
105     // just StrJoin them.
106     observed_call_log_.push_back(
107         MockingBitGenCallFormatter<DistrT, FormatT>{}(dist, result));
108     return result;
109   }
110 
111  private:
112   generator_type gen_;
113   std::deque<std::string> observed_call_log_;
114 };  // namespace random_internal
115 
116 }  // namespace random_internal
117 ABSL_NAMESPACE_END
118 }  // namespace absl
119 
120 #endif  // ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
121