/*
 *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
#include "modules/rtp_rtcp/source/source_tracker.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <list>
#include <random>
#include <set>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "api/rtp_headers.h"
#include "api/rtp_packet_info.h"
#include "api/rtp_packet_infos.h"
#include "test/gmock.h"
#include "test/gtest.h"
27 
28 namespace webrtc {
29 namespace {
30 
31 using ::testing::Combine;
32 using ::testing::ElementsAre;
33 using ::testing::ElementsAreArray;
34 using ::testing::IsEmpty;
35 using ::testing::TestWithParam;
36 using ::testing::Values;
37 
// Upper bound (inclusive) on the number of packet infos generated per frame by
// the randomized test.
constexpr size_t kPacketInfosCountMax = 5;
39 
40 // Simple "guaranteed to be correct" re-implementation of |SourceTracker| for
41 // dual-implementation testing purposes.
42 class ExpectedSourceTracker {
43  public:
ExpectedSourceTracker(Clock * clock)44   explicit ExpectedSourceTracker(Clock* clock) : clock_(clock) {}
45 
OnFrameDelivered(const RtpPacketInfos & packet_infos)46   void OnFrameDelivered(const RtpPacketInfos& packet_infos) {
47     const int64_t now_ms = clock_->TimeInMilliseconds();
48 
49     for (const auto& packet_info : packet_infos) {
50       RtpSource::Extensions extensions = {packet_info.audio_level(),
51                                           packet_info.absolute_capture_time()};
52 
53       for (const auto& csrc : packet_info.csrcs()) {
54         entries_.emplace_front(now_ms, csrc, RtpSourceType::CSRC,
55                                packet_info.rtp_timestamp(), extensions);
56       }
57 
58       entries_.emplace_front(now_ms, packet_info.ssrc(), RtpSourceType::SSRC,
59                              packet_info.rtp_timestamp(), extensions);
60     }
61 
62     PruneEntries(now_ms);
63   }
64 
GetSources() const65   std::vector<RtpSource> GetSources() const {
66     PruneEntries(clock_->TimeInMilliseconds());
67 
68     return std::vector<RtpSource>(entries_.begin(), entries_.end());
69   }
70 
71  private:
PruneEntries(int64_t now_ms) const72   void PruneEntries(int64_t now_ms) const {
73     const int64_t prune_ms = now_ms - 10000;  // 10 seconds
74 
75     std::set<std::pair<RtpSourceType, uint32_t>> seen;
76 
77     auto it = entries_.begin();
78     auto end = entries_.end();
79     while (it != end) {
80       auto next = it;
81       ++next;
82 
83       auto key = std::make_pair(it->source_type(), it->source_id());
84       if (!seen.insert(key).second || it->timestamp_ms() < prune_ms) {
85         entries_.erase(it);
86       }
87 
88       it = next;
89     }
90   }
91 
92   Clock* const clock_;
93 
94   mutable std::list<RtpSource> entries_;
95 };
96 
97 class SourceTrackerRandomTest
98     : public TestWithParam<std::tuple<uint32_t, uint32_t>> {
99  protected:
SourceTrackerRandomTest()100   SourceTrackerRandomTest()
101       : ssrcs_count_(std::get<0>(GetParam())),
102         csrcs_count_(std::get<1>(GetParam())),
103         generator_(42) {}
104 
GeneratePacketInfos()105   RtpPacketInfos GeneratePacketInfos() {
106     size_t count = std::uniform_int_distribution<size_t>(
107         1, kPacketInfosCountMax)(generator_);
108 
109     RtpPacketInfos::vector_type packet_infos;
110     for (size_t i = 0; i < count; ++i) {
111       packet_infos.emplace_back(GenerateSsrc(), GenerateCsrcs(),
112                                 GenerateRtpTimestamp(), GenerateAudioLevel(),
113                                 GenerateAbsoluteCaptureTime(),
114                                 GenerateReceiveTimeMs());
115     }
116 
117     return RtpPacketInfos(std::move(packet_infos));
118   }
119 
GenerateClockAdvanceTimeMilliseconds()120   int64_t GenerateClockAdvanceTimeMilliseconds() {
121     double roll = std::uniform_real_distribution<double>(0.0, 1.0)(generator_);
122 
123     if (roll < 0.05) {
124       return 0;
125     }
126 
127     if (roll < 0.08) {
128       return SourceTracker::kTimeoutMs - 1;
129     }
130 
131     if (roll < 0.11) {
132       return SourceTracker::kTimeoutMs;
133     }
134 
135     if (roll < 0.19) {
136       return std::uniform_int_distribution<int64_t>(
137           SourceTracker::kTimeoutMs,
138           SourceTracker::kTimeoutMs * 1000)(generator_);
139     }
140 
141     return std::uniform_int_distribution<int64_t>(
142         1, SourceTracker::kTimeoutMs - 1)(generator_);
143   }
144 
145  private:
GenerateSsrc()146   uint32_t GenerateSsrc() {
147     return std::uniform_int_distribution<uint32_t>(1, ssrcs_count_)(generator_);
148   }
149 
GenerateCsrcs()150   std::vector<uint32_t> GenerateCsrcs() {
151     std::vector<uint32_t> csrcs;
152     for (size_t i = 1; i <= csrcs_count_ && csrcs.size() < kRtpCsrcSize; ++i) {
153       if (std::bernoulli_distribution(0.5)(generator_)) {
154         csrcs.push_back(i);
155       }
156     }
157 
158     return csrcs;
159   }
160 
GenerateRtpTimestamp()161   uint32_t GenerateRtpTimestamp() {
162     return std::uniform_int_distribution<uint32_t>()(generator_);
163   }
164 
GenerateAudioLevel()165   absl::optional<uint8_t> GenerateAudioLevel() {
166     if (std::bernoulli_distribution(0.25)(generator_)) {
167       return absl::nullopt;
168     }
169 
170     // Workaround for std::uniform_int_distribution<uint8_t> not being allowed.
171     return static_cast<uint8_t>(
172         std::uniform_int_distribution<uint16_t>()(generator_));
173   }
174 
GenerateAbsoluteCaptureTime()175   absl::optional<AbsoluteCaptureTime> GenerateAbsoluteCaptureTime() {
176     if (std::bernoulli_distribution(0.25)(generator_)) {
177       return absl::nullopt;
178     }
179 
180     AbsoluteCaptureTime value;
181 
182     value.absolute_capture_timestamp =
183         std::uniform_int_distribution<uint64_t>()(generator_);
184 
185     if (std::bernoulli_distribution(0.5)(generator_)) {
186       value.estimated_capture_clock_offset = absl::nullopt;
187     } else {
188       value.estimated_capture_clock_offset =
189           std::uniform_int_distribution<int64_t>()(generator_);
190     }
191 
192     return value;
193   }
194 
GenerateReceiveTimeMs()195   int64_t GenerateReceiveTimeMs() {
196     return std::uniform_int_distribution<int64_t>()(generator_);
197   }
198 
199   const uint32_t ssrcs_count_;
200   const uint32_t csrcs_count_;
201 
202   std::mt19937 generator_;
203 };
204 
205 }  // namespace
206 
// Feeds identical random frames to the real |SourceTracker| and to the
// reference |ExpectedSourceTracker|, advancing a shared simulated clock by a
// random amount after each delivery, and requires both to report the same
// sources in the same order after every step.
TEST_P(SourceTrackerRandomTest, RandomOperations) {
  constexpr size_t kIterationsCount = 200;

  SimulatedClock clock(1000000000000ULL);
  SourceTracker actual_tracker(&clock);
  ExpectedSourceTracker expected_tracker(&clock);

  ASSERT_THAT(actual_tracker.GetSources(), IsEmpty());
  ASSERT_THAT(expected_tracker.GetSources(), IsEmpty());

  for (size_t i = 0; i < kIterationsCount; ++i) {
    RtpPacketInfos packet_infos = GeneratePacketInfos();

    actual_tracker.OnFrameDelivered(packet_infos);
    expected_tracker.OnFrameDelivered(packet_infos);

    clock.AdvanceTimeMilliseconds(GenerateClockAdvanceTimeMilliseconds());

    ASSERT_THAT(actual_tracker.GetSources(),
                ElementsAreArray(expected_tracker.GetSources()));
  }
}
229 
230 INSTANTIATE_TEST_SUITE_P(All,
231                          SourceTrackerRandomTest,
232                          Combine(/*ssrcs_count_=*/Values(1, 2, 4),
233                                  /*csrcs_count_=*/Values(0, 1, 3, 7)));
234 
// A freshly constructed tracker reports no sources.
TEST(SourceTrackerTest, StartEmpty) {
  SimulatedClock clock(1000000000000ULL);
  SourceTracker tracker(&clock);

  EXPECT_THAT(tracker.GetSources(), IsEmpty());
}
241 
// Delivering one packet with an SSRC and two CSRCs yields three sources, all
// stamped with the delivery time and carrying the packet's RTP timestamp and
// extensions.
TEST(SourceTrackerTest, OnFrameDeliveredRecordsSources) {
  constexpr uint32_t kSsrc = 10;
  constexpr uint32_t kCsrcs0 = 20;
  constexpr uint32_t kCsrcs1 = 21;
  constexpr uint32_t kRtpTimestamp = 40;
  constexpr absl::optional<uint8_t> kAudioLevel = 50;
  constexpr absl::optional<AbsoluteCaptureTime> kAbsoluteCaptureTime =
      AbsoluteCaptureTime{/*absolute_capture_timestamp=*/12,
                          /*estimated_capture_clock_offset=*/absl::nullopt};
  constexpr int64_t kReceiveTimeMs = 60;

  SimulatedClock clock(1000000000000ULL);
  SourceTracker tracker(&clock);

  tracker.OnFrameDelivered(RtpPacketInfos(
      {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp, kAudioLevel,
                     kAbsoluteCaptureTime, kReceiveTimeMs)}));

  int64_t timestamp_ms = clock.TimeInMilliseconds();
  constexpr RtpSource::Extensions extensions = {kAudioLevel,
                                                kAbsoluteCaptureTime};

  // Expected order: the SSRC first, then the CSRCs in reverse of the order in
  // which they were listed in the packet.
  EXPECT_THAT(tracker.GetSources(),
              ElementsAre(RtpSource(timestamp_ms, kSsrc, RtpSourceType::SSRC,
                                    kRtpTimestamp, extensions),
                          RtpSource(timestamp_ms, kCsrcs1, RtpSourceType::CSRC,
                                    kRtpTimestamp, extensions),
                          RtpSource(timestamp_ms, kCsrcs0, RtpSourceType::CSRC,
                                    kRtpTimestamp, extensions)));
}
272 
// A second delivery 17 ms later refreshes the sources it mentions (kSsrc,
// kCsrcs0), adds the new one (kCsrcs2), and leaves the source it does not
// mention (kCsrcs1) with its original timestamp and extensions, ordered last.
TEST(SourceTrackerTest, OnFrameDeliveredUpdatesSources) {
  constexpr uint32_t kSsrc = 10;
  constexpr uint32_t kCsrcs0 = 20;
  constexpr uint32_t kCsrcs1 = 21;
  constexpr uint32_t kCsrcs2 = 22;
  constexpr uint32_t kRtpTimestamp0 = 40;
  constexpr uint32_t kRtpTimestamp1 = 41;
  constexpr absl::optional<uint8_t> kAudioLevel0 = 50;
  constexpr absl::optional<uint8_t> kAudioLevel1 = absl::nullopt;
  constexpr absl::optional<AbsoluteCaptureTime> kAbsoluteCaptureTime0 =
      AbsoluteCaptureTime{12, 34};
  constexpr absl::optional<AbsoluteCaptureTime> kAbsoluteCaptureTime1 =
      AbsoluteCaptureTime{56, 78};
  constexpr int64_t kReceiveTimeMs0 = 60;
  constexpr int64_t kReceiveTimeMs1 = 61;

  SimulatedClock clock(1000000000000ULL);
  SourceTracker tracker(&clock);

  // First delivery: kSsrc with CSRCs {kCsrcs0, kCsrcs1}.
  tracker.OnFrameDelivered(RtpPacketInfos(
      {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0,
                     kAbsoluteCaptureTime0, kReceiveTimeMs0)}));

  int64_t timestamp_ms_0 = clock.TimeInMilliseconds();

  clock.AdvanceTimeMilliseconds(17);

  // Second delivery: kSsrc with CSRCs {kCsrcs0, kCsrcs2}.
  tracker.OnFrameDelivered(RtpPacketInfos(
      {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs2}, kRtpTimestamp1, kAudioLevel1,
                     kAbsoluteCaptureTime1, kReceiveTimeMs1)}));

  int64_t timestamp_ms_1 = clock.TimeInMilliseconds();

  constexpr RtpSource::Extensions extensions0 = {kAudioLevel0,
                                                 kAbsoluteCaptureTime0};
  constexpr RtpSource::Extensions extensions1 = {kAudioLevel1,
                                                 kAbsoluteCaptureTime1};

  EXPECT_THAT(
      tracker.GetSources(),
      ElementsAre(RtpSource(timestamp_ms_1, kSsrc, RtpSourceType::SSRC,
                            kRtpTimestamp1, extensions1),
                  RtpSource(timestamp_ms_1, kCsrcs2, RtpSourceType::CSRC,
                            kRtpTimestamp1, extensions1),
                  RtpSource(timestamp_ms_1, kCsrcs0, RtpSourceType::CSRC,
                            kRtpTimestamp1, extensions1),
                  RtpSource(timestamp_ms_0, kCsrcs1, RtpSourceType::CSRC,
                            kRtpTimestamp0, extensions0)));
}
322 
// After advancing the clock by exactly kTimeoutMs past the second delivery,
// the source seen only in the first delivery (kCsrcs1) is expected to be gone,
// while the sources from the second delivery are still reported.
TEST(SourceTrackerTest, TimedOutSourcesAreRemoved) {
  constexpr uint32_t kSsrc = 10;
  constexpr uint32_t kCsrcs0 = 20;
  constexpr uint32_t kCsrcs1 = 21;
  constexpr uint32_t kCsrcs2 = 22;
  constexpr uint32_t kRtpTimestamp0 = 40;
  constexpr uint32_t kRtpTimestamp1 = 41;
  constexpr absl::optional<uint8_t> kAudioLevel0 = 50;
  constexpr absl::optional<uint8_t> kAudioLevel1 = absl::nullopt;
  constexpr absl::optional<AbsoluteCaptureTime> kAbsoluteCaptureTime0 =
      AbsoluteCaptureTime{12, 34};
  constexpr absl::optional<AbsoluteCaptureTime> kAbsoluteCaptureTime1 =
      AbsoluteCaptureTime{56, 78};
  constexpr int64_t kReceiveTimeMs0 = 60;
  constexpr int64_t kReceiveTimeMs1 = 61;

  SimulatedClock clock(1000000000000ULL);
  SourceTracker tracker(&clock);

  // First delivery: kSsrc with CSRCs {kCsrcs0, kCsrcs1}.
  tracker.OnFrameDelivered(RtpPacketInfos(
      {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0,
                     kAbsoluteCaptureTime0, kReceiveTimeMs0)}));

  clock.AdvanceTimeMilliseconds(17);

  // Second delivery: kSsrc with CSRCs {kCsrcs0, kCsrcs2}.
  tracker.OnFrameDelivered(RtpPacketInfos(
      {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs2}, kRtpTimestamp1, kAudioLevel1,
                     kAbsoluteCaptureTime1, kReceiveTimeMs1)}));

  int64_t timestamp_ms_1 = clock.TimeInMilliseconds();

  // The first delivery is now kTimeoutMs + 17 ms old; the second is exactly
  // kTimeoutMs old.
  clock.AdvanceTimeMilliseconds(SourceTracker::kTimeoutMs);

  constexpr RtpSource::Extensions extensions1 = {kAudioLevel1,
                                                 kAbsoluteCaptureTime1};

  EXPECT_THAT(
      tracker.GetSources(),
      ElementsAre(RtpSource(timestamp_ms_1, kSsrc, RtpSourceType::SSRC,
                            kRtpTimestamp1, extensions1),
                  RtpSource(timestamp_ms_1, kCsrcs2, RtpSourceType::CSRC,
                            kRtpTimestamp1, extensions1),
                  RtpSource(timestamp_ms_1, kCsrcs0, RtpSourceType::CSRC,
                            kRtpTimestamp1, extensions1)));
}
368 
369 }  // namespace webrtc
370