• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <gtest/gtest.h>
12 
13 #include <algorithm>
14 #include <cassert>
15 #include <cstddef>
16 #include <cstdlib>
17 #include <functional>
18 #include <limits>
19 #include <random>
20 #include <vector>
21 
22 #include <xnnpack/params.h>
23 
24 
25 class LUTNormMicrokernelTester {
26  public:
n(size_t n)27   inline LUTNormMicrokernelTester& n(size_t n) {
28     assert(n != 0);
29     this->n_ = n;
30     return *this;
31   }
32 
n()33   inline size_t n() const {
34     return this->n_;
35   }
36 
inplace(bool inplace)37   inline LUTNormMicrokernelTester& inplace(bool inplace) {
38     this->inplace_ = inplace;
39     return *this;
40   }
41 
inplace()42   inline bool inplace() const {
43     return this->inplace_;
44   }
45 
iterations(size_t iterations)46   inline LUTNormMicrokernelTester& iterations(size_t iterations) {
47     this->iterations_ = iterations;
48     return *this;
49   }
50 
iterations()51   inline size_t iterations() const {
52     return this->iterations_;
53   }
54 
Test(xnn_u8_lut32norm_ukernel_function lutnorm)55   void Test(xnn_u8_lut32norm_ukernel_function lutnorm) const {
56     std::random_device random_device;
57     auto rng = std::mt19937(random_device());
58     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
59     auto u32rng = std::bind(
60       std::uniform_int_distribution<uint32_t>(1, std::numeric_limits<uint32_t>::max() / (257 * n())),
61       rng);
62 
63     std::vector<uint8_t> x(n());
64     std::vector<uint32_t> t(256);
65     std::vector<uint8_t> y(n());
66     std::vector<float> y_ref(n());
67     for (size_t iteration = 0; iteration < iterations(); iteration++) {
68       std::generate(x.begin(), x.end(), std::ref(u8rng));
69       std::generate(t.begin(), t.end(), std::ref(u32rng));
70       if (inplace()) {
71         std::generate(y.begin(), y.end(), std::ref(u8rng));
72       } else {
73         std::fill(y.begin(), y.end(), 0xA5);
74       }
75       const uint8_t* x_data = inplace() ? y.data() : x.data();
76 
77       // Compute reference results.
78       uint32_t sum = 0;
79       for (size_t i = 0; i < n(); i++) {
80         sum += t[x_data[i]];
81       }
82       for (size_t i = 0; i < n(); i++) {
83         y_ref[i] = 256.0f * float(t[x_data[i]]) / float(sum);
84         y_ref[i] = std::min(y_ref[i], 255.0f);
85       }
86 
87       // Call optimized micro-kernel.
88       lutnorm(n(), x_data, t.data(), y.data());
89 
90       // Verify results.
91       for (size_t i = 0; i < n(); i++) {
92         ASSERT_NEAR(y_ref[i], float(y[i]), 0.5f)
93           << "at position " << i << ", n = " << n() << ", sum = " << sum;
94       }
95     }
96   }
97 
98  private:
99   size_t n_{1};
100   bool inplace_{false};
101   size_t iterations_{15};
102 };
103