• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef TESTS_PARAMGENERATOR_H_
16 #define TESTS_PARAMGENERATOR_H_
17 
#include <array>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <tuple>
#include <utility>
#include <vector>
20 
// ParamStruct is a custom struct which ParamGenerator will yield when iterating.
// The types Params... should be the same as the types passed to the constructor
// of ParamStruct.
//
// ParamGenerator enumerates the Cartesian product of its input vectors, one
// ParamStruct per combination. The last vector varies fastest: for inputs
// {1, 2} and {'a', 'b'} the order is (1, 'a'), (1, 'b'), (2, 'a'), (2, 'b').
// If any input vector is empty there are no combinations and begin() == end().
template <typename ParamStruct, typename... Params>
class ParamGenerator {
    using ParamTuple = std::tuple<std::vector<Params>...>;
    using Index = std::array<size_t, sizeof...(Params)>;
    // A type alias rather than a static constexpr data member: callers build a
    // fresh IndexSequence{} where needed, so no out-of-line definition is ever
    // required (static constexpr members only became implicitly inline in C++17).
    using IndexSequence = std::make_index_sequence<sizeof...(Params)>;

    // Using an N-dimensional Index, extract params from ParamTuple and pass
    // them to the constructor of ParamStruct.
    template <size_t... Is>
    static ParamStruct GetParam(const ParamTuple& params,
                                const Index& index,
                                std::index_sequence<Is...>) {
        return ParamStruct(std::get<Is>(params)[std::get<Is>(index)]...);
    }

    // Get the last value index into a ParamTuple. For an empty vector the
    // subtraction wraps around to SIZE_MAX; that value is only ever used by
    // end(), which immediately advances past it.
    template <size_t... Is>
    static Index GetLastIndex(const ParamTuple& params, std::index_sequence<Is...>) {
        return Index{std::get<Is>(params).size() - 1 ...};
    }

    // Returns true if at least one of the parameter vectors is empty.
    template <size_t... Is>
    static bool AnyEmpty(const ParamTuple& params, std::index_sequence<Is...>) {
        // Spelled as an explicit initializer_list so an empty pack still compiles.
        const std::initializer_list<bool> emptiness = {std::get<Is>(params).empty()...};
        for (bool isEmpty : emptiness) {
            if (isEmpty) {
                return true;
            }
        }
        return false;
    }

  public:
    using value_type = ParamStruct;

    // Each parameter list is taken by value and moved into storage, so callers
    // that pass temporaries do not pay for a second copy.
    ParamGenerator(std::vector<Params>... params)
        : mParams(std::move(params)...), mIsEmpty(AnyEmpty(mParams, IndexSequence{})) {
    }

    // Iterator over all combinations. It holds its own copy of the parameter
    // vectors, so it remains usable after the generator that created it is gone.
    class Iterator {
      public:
        // Member typedefs replace the std::iterator base, deprecated in C++17.
        // operator* returns by value, so `reference` is the value type itself.
        using iterator_category = std::forward_iterator_tag;
        using value_type = ParamStruct;
        using difference_type = std::ptrdiff_t;
        using pointer = const ParamStruct*;
        using reference = ParamStruct;

        Iterator& operator++() {
            // Increment the Index by 1 like an odometer: if the i'th place is at
            // its maximum, reset it to 0 and carry into the (i-1)'th place.
            for (size_t i = mIndex.size(); i-- > 0;) {
                if (mIndex[i] < mLastIndex[i]) {
                    ++mIndex[i];
                    return *this;
                }
                mIndex[i] = 0;
            }

            // Every place wrapped around: set a marker that the iterator has
            // reached the end.
            mEnd = true;
            return *this;
        }

        Iterator operator++(int) {
            Iterator previous = *this;
            ++*this;
            return previous;
        }

        bool operator==(const Iterator& other) const {
            return mEnd == other.mEnd && mIndex == other.mIndex;
        }

        bool operator!=(const Iterator& other) const {
            return !(*this == other);
        }

        ParamStruct operator*() const {
            return GetParam(mParams, mIndex, IndexSequence{});
        }

      private:
        friend class ParamGenerator;

        Iterator(ParamTuple params, Index index)
            : mParams(std::move(params)),
              mIndex(index),
              mLastIndex(GetLastIndex(mParams, IndexSequence{})) {
        }

        ParamTuple mParams;
        Index mIndex;
        Index mLastIndex;
        bool mEnd = false;
    };

    Iterator begin() const {
        // An empty dimension means there are no combinations at all.
        if (mIsEmpty) {
            return end();
        }
        return Iterator(mParams, Index{});
    }

    Iterator end() const {
        // The end iterator sits "one past" the last combination: start at the
        // last index and advance once, which wraps every place back to 0 and
        // sets the end marker, exactly as normal iteration does.
        Iterator iter(mParams, GetLastIndex(mParams, IndexSequence{}));
        ++iter;
        return iter;
    }

  private:
    ParamTuple mParams;
    bool mIsEmpty;
};
116 
// Forward declarations; the full definitions are not part of this header.
struct AdapterTestParam;
struct BackendTestConfig;

// Helpers used by MakeParamGenerator below.
namespace detail {
    // Declared here and defined elsewhere. It receives a pointer/count pair
    // (MakeParamGenerator passes a vector's data() and size()) and returns a
    // std::vector<AdapterTestParam>. NOTE(review): the selection rules live in
    // the definition, which is not visible here.
    std::vector<AdapterTestParam>
    GetAvailableAdapterTestParamsForBackends(const BackendTestConfig* configs, size_t configCount);
}  // namespace detail
125 
126 template <typename Param, typename... Params>
MakeParamGenerator(std::vector<BackendTestConfig> && first,std::initializer_list<Params> &&...params)127 auto MakeParamGenerator(std::vector<BackendTestConfig>&& first,
128                         std::initializer_list<Params>&&... params) {
129     return ParamGenerator<Param, AdapterTestParam, Params...>(
130         ::detail::GetAvailableAdapterTestParamsForBackends(first.data(), first.size()),
131         std::forward<std::initializer_list<Params>&&>(params)...);
132 }
133 template <typename Param, typename... Params>
MakeParamGenerator(std::vector<BackendTestConfig> && first,std::vector<Params> &&...params)134 auto MakeParamGenerator(std::vector<BackendTestConfig>&& first, std::vector<Params>&&... params) {
135     return ParamGenerator<Param, AdapterTestParam, Params...>(
136         ::detail::GetAvailableAdapterTestParamsForBackends(first.data(), first.size()),
137         std::forward<std::vector<Params>&&>(params)...);
138 }
139 
140 #endif  // TESTS_PARAMGENERATOR_H_
141