• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack/params.h>
22 
23 
24 class LUTMicrokernelTester {
25  public:
n(size_t n)26   inline LUTMicrokernelTester& n(size_t n) {
27     assert(n != 0);
28     this->n_ = n;
29     return *this;
30   }
31 
n()32   inline size_t n() const {
33     return this->n_;
34   }
35 
inplace(bool inplace)36   inline LUTMicrokernelTester& inplace(bool inplace) {
37     this->inplace_ = inplace;
38     return *this;
39   }
40 
inplace()41   inline bool inplace() const {
42     return this->inplace_;
43   }
44 
iterations(size_t iterations)45   inline LUTMicrokernelTester& iterations(size_t iterations) {
46     this->iterations_ = iterations;
47     return *this;
48   }
49 
iterations()50   inline size_t iterations() const {
51     return this->iterations_;
52   }
53 
Test(xnn_x8_lut_ukernel_function lut)54   void Test(xnn_x8_lut_ukernel_function lut) const {
55     std::random_device random_device;
56     auto rng = std::mt19937(random_device());
57     auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
58 
59     std::vector<uint8_t> x(n());
60     std::vector<uint8_t> t(256);
61     std::vector<uint8_t> y(n());
62     std::vector<uint8_t> y_ref(n());
63     for (size_t iteration = 0; iteration < iterations(); iteration++) {
64       std::generate(x.begin(), x.end(), std::ref(u8rng));
65       std::generate(t.begin(), t.end(), std::ref(u8rng));
66       if (inplace()) {
67         std::generate(y.begin(), y.end(), std::ref(u8rng));
68       } else {
69         std::fill(y.begin(), y.end(), 0xA5);
70       }
71       const uint8_t* x_data = inplace() ? y.data() : x.data();
72 
73       // Compute reference results.
74       for (size_t i = 0; i < n(); i++) {
75         y_ref[i] = t[x_data[i]];
76       }
77 
78       // Call optimized micro-kernel.
79       lut(n(), x_data, t.data(), y.data());
80 
81       // Verify results.
82       for (size_t i = 0; i < n(); i++) {
83         ASSERT_EQ(uint32_t(y_ref[i]), uint32_t(y[i]))
84           << "at position " << i << ", n = " << n();
85       }
86     }
87   }
88 
89  private:
90   size_t n_{1};
91   bool inplace_{false};
92   size_t iterations_{15};
93 };
94