1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
12
13 #include "test/gtest.h"
14
15 namespace webrtc {
16 namespace rnn_vad {
17 namespace {
18
19 // Compare the elements of two given array views.
20 template <typename T, std::ptrdiff_t S>
ExpectEq(rtc::ArrayView<const T,S> a,rtc::ArrayView<const T,S> b)21 void ExpectEq(rtc::ArrayView<const T, S> a, rtc::ArrayView<const T, S> b) {
22 for (int i = 0; i < S; ++i) {
23 SCOPED_TRACE(i);
24 EXPECT_EQ(a[i], b[i]);
25 }
26 }
27
28 // Test push/read sequences.
29 template <typename T, int S, int N>
TestRingBuffer()30 void TestRingBuffer() {
31 SCOPED_TRACE(N);
32 SCOPED_TRACE(S);
33 std::array<T, S> prev_pushed_array;
34 std::array<T, S> pushed_array;
35 rtc::ArrayView<const T, S> pushed_array_view(pushed_array.data(), S);
36
37 // Init.
38 RingBuffer<T, S, N> ring_buf;
39 ring_buf.GetArrayView(0);
40 pushed_array.fill(0);
41 ring_buf.Push(pushed_array_view);
42 ExpectEq(pushed_array_view, ring_buf.GetArrayView(0));
43
44 // Push N times and check most recent and second most recent.
45 for (T v = 1; v <= static_cast<T>(N); ++v) {
46 SCOPED_TRACE(v);
47 prev_pushed_array = pushed_array;
48 pushed_array.fill(v);
49 ring_buf.Push(pushed_array_view);
50 ExpectEq(pushed_array_view, ring_buf.GetArrayView(0));
51 if (N > 1) {
52 pushed_array.fill(v - 1);
53 ExpectEq(pushed_array_view, ring_buf.GetArrayView(1));
54 }
55 }
56
57 // Check buffer.
58 for (int delay = 2; delay < N; ++delay) {
59 SCOPED_TRACE(delay);
60 T expected_value = N - static_cast<T>(delay);
61 pushed_array.fill(expected_value);
62 ExpectEq(pushed_array_view, ring_buf.GetArrayView(delay));
63 }
64 }
65
66 // Check that for different delays, different views are returned.
TEST(RnnVadTest,RingBufferArrayViews)67 TEST(RnnVadTest, RingBufferArrayViews) {
68 constexpr int s = 3;
69 constexpr int n = 4;
70 RingBuffer<int, s, n> ring_buf;
71 std::array<int, s> pushed_array;
72 pushed_array.fill(1);
73 for (int k = 0; k <= n; ++k) { // Push data n + 1 times.
74 SCOPED_TRACE(k);
75 // Check array views.
76 for (int i = 0; i < n; ++i) {
77 SCOPED_TRACE(i);
78 auto view_i = ring_buf.GetArrayView(i);
79 for (int j = i + 1; j < n; ++j) {
80 SCOPED_TRACE(j);
81 auto view_j = ring_buf.GetArrayView(j);
82 EXPECT_NE(view_i, view_j);
83 }
84 }
85 ring_buf.Push(pushed_array);
86 }
87 }
88
TEST(RnnVadTest,RingBufferUnsigned)89 TEST(RnnVadTest, RingBufferUnsigned) {
90 TestRingBuffer<uint8_t, 1, 1>();
91 TestRingBuffer<uint8_t, 2, 5>();
92 TestRingBuffer<uint8_t, 5, 2>();
93 TestRingBuffer<uint8_t, 5, 5>();
94 }
95
TEST(RnnVadTest,RingBufferSigned)96 TEST(RnnVadTest, RingBufferSigned) {
97 TestRingBuffer<int, 1, 1>();
98 TestRingBuffer<int, 2, 5>();
99 TestRingBuffer<int, 5, 2>();
100 TestRingBuffer<int, 5, 5>();
101 }
102
TEST(RnnVadTest,RingBufferFloating)103 TEST(RnnVadTest, RingBufferFloating) {
104 TestRingBuffer<float, 1, 1>();
105 TestRingBuffer<float, 2, 5>();
106 TestRingBuffer<float, 5, 2>();
107 TestRingBuffer<float, 5, 5>();
108 }
109
110 } // namespace
111 } // namespace rnn_vad
112 } // namespace webrtc
113