• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_coding/neteq/tools/neteq_quality_test.h"
12 
13 #include <stdio.h>
14 
15 #include <cmath>
16 
17 #include "absl/flags/flag.h"
18 #include "modules/audio_coding/neteq/default_neteq_factory.h"
19 #include "modules/audio_coding/neteq/tools/neteq_quality_test.h"
20 #include "modules/audio_coding/neteq/tools/output_audio_file.h"
21 #include "modules/audio_coding/neteq/tools/output_wav_file.h"
22 #include "modules/audio_coding/neteq/tools/resample_input_audio_file.h"
23 #include "rtc_base/checks.h"
24 #include "system_wrappers/include/clock.h"
25 #include "test/testsupport/file_utils.h"
26 
DefaultInFilename()27 const std::string& DefaultInFilename() {
28   static const std::string path =
29       ::webrtc::test::ResourcePath("audio_coding/speech_mono_16kHz", "pcm");
30   return path;
31 }
32 
DefaultOutFilename()33 const std::string& DefaultOutFilename() {
34   static const std::string path =
35       ::webrtc::test::OutputPath() + "neteq_quality_test_out.pcm";
36   return path;
37 }
38 
39 ABSL_FLAG(
40     std::string,
41     in_filename,
42     DefaultInFilename(),
43     "Filename for input audio (specify sample rate with --input_sample_rate, "
44     "and channels with --channels).");
45 
46 ABSL_FLAG(int, input_sample_rate, 16000, "Sample rate of input file in Hz.");
47 
48 ABSL_FLAG(int, channels, 1, "Number of channels in input audio.");
49 
50 ABSL_FLAG(std::string,
51           out_filename,
52           DefaultOutFilename(),
53           "Name of output audio file.");
54 
55 ABSL_FLAG(
56     int,
57     runtime_ms,
58     10000,
59     "Simulated runtime (milliseconds). -1 will consume the complete file.");
60 
61 ABSL_FLAG(int, packet_loss_rate, 10, "Percentile of packet loss.");
62 
63 ABSL_FLAG(int,
64           random_loss_mode,
65           ::webrtc::test::kUniformLoss,
66           "Random loss mode: 0--no loss, 1--uniform loss, 2--Gilbert Elliot "
67           "loss, 3--fixed loss.");
68 
69 ABSL_FLAG(int,
70           burst_length,
71           30,
72           "Burst length in milliseconds, only valid for Gilbert Elliot loss.");
73 
74 ABSL_FLAG(float, drift_factor, 0.0, "Time drift factor.");
75 
76 ABSL_FLAG(int,
77           preload_packets,
78           1,
79           "Preload the buffer with this many packets.");
80 
81 ABSL_FLAG(std::string,
82           loss_events,
83           "",
84           "List of loss events time and duration separated by comma: "
85           "<first_event_time> <first_event_duration>, <second_event_time> "
86           "<second_event_duration>, ...");
87 
88 namespace webrtc {
89 namespace test {
90 
91 namespace {
92 
CreateNetEq(const NetEq::Config & config,Clock * clock,const rtc::scoped_refptr<AudioDecoderFactory> & decoder_factory)93 std::unique_ptr<NetEq> CreateNetEq(
94     const NetEq::Config& config,
95     Clock* clock,
96     const rtc::scoped_refptr<AudioDecoderFactory>& decoder_factory) {
97   return DefaultNetEqFactory().CreateNetEq(config, decoder_factory, clock);
98 }
99 
100 }  // namespace
101 
// RTP payload type under which the codec under test is registered in NetEq.
const uint8_t kPayloadType{95};
// Duration of the audio frame returned by each NetEq::GetAudio() call.
const int kOutputSizeMs{10};
// Seed passed to srand() so every derived test sees the same loss profile.
const int kInitSeed{0x12345678};
// Time granularity at which the loss models are sampled.
const int kPacketLossTimeUnitMs{10};
106 
107 // Common validator for file names.
ValidateFilename(const std::string & value,bool is_output)108 static bool ValidateFilename(const std::string& value, bool is_output) {
109   if (!is_output) {
110     RTC_CHECK_NE(value.substr(value.find_last_of(".") + 1), "wav")
111         << "WAV file input is not supported";
112   }
113   FILE* fid =
114       is_output ? fopen(value.c_str(), "wb") : fopen(value.c_str(), "rb");
115   if (fid == nullptr)
116     return false;
117   fclose(fid);
118   return true;
119 }
120 
// ProbTrans00Solver() computes the transition probability from the no-loss
// state to itself (prob_trans_00) in a modified Gilbert Elliot packet loss
// model. A packet is not lost only if all |units| drawings within its
// duration result in no-loss. The returned probability makes the overall
// packet loss rate equal |loss_rate|.
//
// With the stationary no-loss probability
//     pi_0 = prob_trans_10 / (prob_trans_10 + 1 - prob_trans_00),
// the target pi_0 * prob_trans_00 ^ (units - 1) == 1 - loss_rate is
// equivalent to f(prob_trans_00) == 0, where
//     f(x) = x ^ (units - 1) + a x + b,
//     a = (1 - loss_rate) / prob_trans_10,
//     b = -(1 - loss_rate) * (1 + 1 / prob_trans_10).
// Expects units >= 1, 0 <= loss_rate < 1 and prob_trans_10 > 0. The result
// is clamped to [0, 1].
static double ProbTrans00Solver(int units,
                                double loss_rate,
                                double prob_trans_10) {
  const double a = (1.0 - loss_rate) / prob_trans_10;
  const double b = (loss_rate - 1.0) * (1.0 + 1.0 / prob_trans_10);

  if (units == 1) {
    // f(x) = 1 + a x + b is linear, so solve it in closed form. Newton's
    // method below would evaluate 0 * pow(0, -1) here. The root is
    // prob_trans_00; the caller derives prob_trans_01 as 1 - prob_trans_00.
    const double x = -(1.0 + b) / a;
    if (x > 1.0)
      return 1.0;
    if (x < 0.0)
      return 0.0;
    return x;
  }

  // There is a unique solution in [0, 1]:
  // - f'(x) = (units - 1) x ^ (units - 2) + a is strictly positive there.
  // - f(0) = b <= 0 while f(1) = loss_rate >= 0.
  // Solve with Newton's method, x(k+1) = x(k) - f(x(k)) / f'(x(k)), clamping
  // every iterate to [0, 1].
  const double kPrecision = 0.001;
  const int kIterations = 100;
  double x = 0.0;  // Starting point.
  double f = b;    // f(0).
  for (int iter = 0; std::fabs(f) >= kPrecision && iter < kIterations;
       ++iter) {
    const double f_p = (units - 1.0) * std::pow(x, units - 2) + a;
    x -= f / f_p;
    if (x > 1.0) {
      x = 1.0;
    } else if (x < 0.0) {
      x = 0.0;
    }
    f = std::pow(x, units - 1) + a * x + b;
  }
  return x;
}
163 
NetEqQualityTest(int block_duration_ms,int in_sampling_khz,int out_sampling_khz,const SdpAudioFormat & format,const rtc::scoped_refptr<AudioDecoderFactory> & decoder_factory)164 NetEqQualityTest::NetEqQualityTest(
165     int block_duration_ms,
166     int in_sampling_khz,
167     int out_sampling_khz,
168     const SdpAudioFormat& format,
169     const rtc::scoped_refptr<AudioDecoderFactory>& decoder_factory)
170     : audio_format_(format),
171       channels_(absl::GetFlag(FLAGS_channels)),
172       decoded_time_ms_(0),
173       decodable_time_ms_(0),
174       drift_factor_(absl::GetFlag(FLAGS_drift_factor)),
175       packet_loss_rate_(absl::GetFlag(FLAGS_packet_loss_rate)),
176       block_duration_ms_(block_duration_ms),
177       in_sampling_khz_(in_sampling_khz),
178       out_sampling_khz_(out_sampling_khz),
179       in_size_samples_(
180           static_cast<size_t>(in_sampling_khz_ * block_duration_ms_)),
181       payload_size_bytes_(0),
182       max_payload_bytes_(0),
183       in_file_(
184           new ResampleInputAudioFile(absl::GetFlag(FLAGS_in_filename),
185                                      absl::GetFlag(FLAGS_input_sample_rate),
186                                      in_sampling_khz * 1000,
187                                      absl::GetFlag(FLAGS_runtime_ms) > 0)),
188       rtp_generator_(
189           new RtpGenerator(in_sampling_khz_, 0, 0, decodable_time_ms_)),
190       total_payload_size_bytes_(0) {
191   // Flag validation
192   RTC_CHECK(ValidateFilename(absl::GetFlag(FLAGS_in_filename), false))
193       << "Invalid input filename.";
194 
195   RTC_CHECK(absl::GetFlag(FLAGS_input_sample_rate) == 8000 ||
196             absl::GetFlag(FLAGS_input_sample_rate) == 16000 ||
197             absl::GetFlag(FLAGS_input_sample_rate) == 32000 ||
198             absl::GetFlag(FLAGS_input_sample_rate) == 48000)
199       << "Invalid sample rate should be 8000, 16000, 32000 or 48000 Hz.";
200 
201   RTC_CHECK_EQ(absl::GetFlag(FLAGS_channels), 1)
202       << "Invalid number of channels, current support only 1.";
203 
204   RTC_CHECK(ValidateFilename(absl::GetFlag(FLAGS_out_filename), true))
205       << "Invalid output filename.";
206 
207   RTC_CHECK(absl::GetFlag(FLAGS_packet_loss_rate) >= 0 &&
208             absl::GetFlag(FLAGS_packet_loss_rate) <= 100)
209       << "Invalid packet loss percentile, should be between 0 and 100.";
210 
211   RTC_CHECK(absl::GetFlag(FLAGS_random_loss_mode) >= 0 &&
212             absl::GetFlag(FLAGS_random_loss_mode) < kLastLossMode)
213       << "Invalid random packet loss mode, should be between 0 and "
214       << kLastLossMode - 1 << ".";
215 
216   RTC_CHECK_GE(absl::GetFlag(FLAGS_burst_length), kPacketLossTimeUnitMs)
217       << "Invalid burst length, should be greater than or equal to "
218       << kPacketLossTimeUnitMs << " ms.";
219 
220   RTC_CHECK_GT(absl::GetFlag(FLAGS_drift_factor), -0.1)
221       << "Invalid drift factor, should be greater than -0.1.";
222 
223   RTC_CHECK_GE(absl::GetFlag(FLAGS_preload_packets), 0)
224       << "Invalid number of packets to preload; must be non-negative.";
225 
226   const std::string out_filename = absl::GetFlag(FLAGS_out_filename);
227   const std::string log_filename = out_filename + ".log";
228   log_file_.open(log_filename.c_str(), std::ofstream::out);
229   RTC_CHECK(log_file_.is_open());
230 
231   if (out_filename.size() >= 4 &&
232       out_filename.substr(out_filename.size() - 4) == ".wav") {
233     // Open a wav file.
234     output_.reset(
235         new webrtc::test::OutputWavFile(out_filename, 1000 * out_sampling_khz));
236   } else {
237     // Open a pcm file.
238     output_.reset(new webrtc::test::OutputAudioFile(out_filename));
239   }
240 
241   NetEq::Config config;
242   config.sample_rate_hz = out_sampling_khz_ * 1000;
243   neteq_ = CreateNetEq(config, Clock::GetRealTimeClock(), decoder_factory);
244   max_payload_bytes_ = in_size_samples_ * channels_ * sizeof(int16_t);
245   in_data_.reset(new int16_t[in_size_samples_ * channels_]);
246 }
247 
~NetEqQualityTest()248 NetEqQualityTest::~NetEqQualityTest() {
249   log_file_.close();
250 }
251 
Lost(int now_ms)252 bool NoLoss::Lost(int now_ms) {
253   return false;
254 }
255 
UniformLoss(double loss_rate)256 UniformLoss::UniformLoss(double loss_rate) : loss_rate_(loss_rate) {}
257 
Lost(int now_ms)258 bool UniformLoss::Lost(int now_ms) {
259   int drop_this = rand();
260   return (drop_this < loss_rate_ * RAND_MAX);
261 }
262 
GilbertElliotLoss(double prob_trans_11,double prob_trans_01)263 GilbertElliotLoss::GilbertElliotLoss(double prob_trans_11, double prob_trans_01)
264     : prob_trans_11_(prob_trans_11),
265       prob_trans_01_(prob_trans_01),
266       lost_last_(false),
267       uniform_loss_model_(new UniformLoss(0)) {}
268 
~GilbertElliotLoss()269 GilbertElliotLoss::~GilbertElliotLoss() {}
270 
Lost(int now_ms)271 bool GilbertElliotLoss::Lost(int now_ms) {
272   // Simulate bursty channel (Gilbert model).
273   // (1st order) Markov chain model with memory of the previous/last
274   // packet state (lost or received).
275   if (lost_last_) {
276     // Previous packet was not received.
277     uniform_loss_model_->set_loss_rate(prob_trans_11_);
278     return lost_last_ = uniform_loss_model_->Lost(now_ms);
279   } else {
280     uniform_loss_model_->set_loss_rate(prob_trans_01_);
281     return lost_last_ = uniform_loss_model_->Lost(now_ms);
282   }
283 }
284 
FixedLossModel(std::set<FixedLossEvent,FixedLossEventCmp> loss_events)285 FixedLossModel::FixedLossModel(
286     std::set<FixedLossEvent, FixedLossEventCmp> loss_events)
287     : loss_events_(loss_events) {
288   loss_events_it_ = loss_events_.begin();
289 }
290 
~FixedLossModel()291 FixedLossModel::~FixedLossModel() {}
292 
Lost(int now_ms)293 bool FixedLossModel::Lost(int now_ms) {
294   if (loss_events_it_ != loss_events_.end() &&
295       now_ms > loss_events_it_->start_ms) {
296     if (now_ms <= loss_events_it_->start_ms + loss_events_it_->duration_ms) {
297       return true;
298     } else {
299       ++loss_events_it_;
300       return false;
301     }
302   }
303   return false;
304 }
305 
SetUp()306 void NetEqQualityTest::SetUp() {
307   ASSERT_TRUE(neteq_->RegisterPayloadType(kPayloadType, audio_format_));
308   rtp_generator_->set_drift_factor(drift_factor_);
309 
310   int units = block_duration_ms_ / kPacketLossTimeUnitMs;
311   switch (absl::GetFlag(FLAGS_random_loss_mode)) {
312     case kUniformLoss: {
313       // |unit_loss_rate| is the packet loss rate for each unit time interval
314       // (kPacketLossTimeUnitMs). Since a packet loss event is generated if any
315       // of |block_duration_ms_ / kPacketLossTimeUnitMs| unit time intervals of
316       // a full packet duration is drawn with a loss, |unit_loss_rate| fulfills
317       // (1 - unit_loss_rate) ^ (block_duration_ms_ / kPacketLossTimeUnitMs) ==
318       // 1 - packet_loss_rate.
319       double unit_loss_rate =
320           (1.0 - std::pow(1.0 - 0.01 * packet_loss_rate_, 1.0 / units));
321       loss_model_.reset(new UniformLoss(unit_loss_rate));
322       break;
323     }
324     case kGilbertElliotLoss: {
325       // |FLAGS_burst_length| should be integer times of kPacketLossTimeUnitMs.
326       ASSERT_EQ(0, absl::GetFlag(FLAGS_burst_length) % kPacketLossTimeUnitMs);
327 
328       // We do not allow 100 percent packet loss in Gilbert Elliot model, which
329       // makes no sense.
330       ASSERT_GT(100, packet_loss_rate_);
331 
332       // To guarantee the overall packet loss rate, transition probabilities
333       // need to satisfy:
334       // pi_0 * (1 - prob_trans_01_) ^ units +
335       //     pi_1 * prob_trans_10_ ^ (units - 1) == 1 - loss_rate
336       // pi_0 = prob_trans_10 / (prob_trans_10 + prob_trans_01_)
337       //     is the stationary state probability of no-loss
338       // pi_1 = prob_trans_01_ / (prob_trans_10 + prob_trans_01_)
339       //     is the stationary state probability of loss
340       // After a derivation prob_trans_00 should satisfy:
341       // prob_trans_00 ^ (units - 1) = (loss_rate - 1) / prob_trans_10 *
342       //     prob_trans_00 + (1 - loss_rate) * (1 + 1 / prob_trans_10).
343       double loss_rate = 0.01f * packet_loss_rate_;
344       double prob_trans_10 =
345           1.0f * kPacketLossTimeUnitMs / absl::GetFlag(FLAGS_burst_length);
346       double prob_trans_00 = ProbTrans00Solver(units, loss_rate, prob_trans_10);
347       loss_model_.reset(
348           new GilbertElliotLoss(1.0f - prob_trans_10, 1.0f - prob_trans_00));
349       break;
350     }
351     case kFixedLoss: {
352       std::istringstream loss_events_stream(absl::GetFlag(FLAGS_loss_events));
353       std::string loss_event_string;
354       std::set<FixedLossEvent, FixedLossEventCmp> loss_events;
355       while (std::getline(loss_events_stream, loss_event_string, ',')) {
356         std::vector<int> loss_event_params;
357         std::istringstream loss_event_params_stream(loss_event_string);
358         std::copy(std::istream_iterator<int>(loss_event_params_stream),
359                   std::istream_iterator<int>(),
360                   std::back_inserter(loss_event_params));
361         RTC_CHECK_EQ(loss_event_params.size(), 2);
362         auto result = loss_events.insert(
363             FixedLossEvent(loss_event_params[0], loss_event_params[1]));
364         RTC_CHECK(result.second);
365       }
366       RTC_CHECK_GT(loss_events.size(), 0);
367       loss_model_.reset(new FixedLossModel(loss_events));
368       break;
369     }
370     default: {
371       loss_model_.reset(new NoLoss);
372       break;
373     }
374   }
375 
376   // Make sure that the packet loss profile is same for all derived tests.
377   srand(kInitSeed);
378 }
379 
Log()380 std::ofstream& NetEqQualityTest::Log() {
381   return log_file_;
382 }
383 
PacketLost()384 bool NetEqQualityTest::PacketLost() {
385   int cycles = block_duration_ms_ / kPacketLossTimeUnitMs;
386 
387   // The loop is to make sure that codecs with different block lengths share the
388   // same packet loss profile.
389   bool lost = false;
390   for (int idx = 0; idx < cycles; idx++) {
391     if (loss_model_->Lost(decoded_time_ms_)) {
392       // The packet will be lost if any of the drawings indicates a loss, but
393       // the loop has to go on to make sure that codecs with different block
394       // lengths keep the same pace.
395       lost = true;
396     }
397   }
398   return lost;
399 }
400 
Transmit()401 int NetEqQualityTest::Transmit() {
402   int packet_input_time_ms = rtp_generator_->GetRtpHeader(
403       kPayloadType, in_size_samples_, &rtp_header_);
404   Log() << "Packet of size " << payload_size_bytes_ << " bytes, for frame at "
405         << packet_input_time_ms << " ms ";
406   if (payload_size_bytes_ > 0) {
407     if (!PacketLost()) {
408       int ret = neteq_->InsertPacket(
409           rtp_header_,
410           rtc::ArrayView<const uint8_t>(payload_.data(), payload_size_bytes_));
411       if (ret != NetEq::kOK)
412         return -1;
413       Log() << "was sent.";
414     } else {
415       Log() << "was lost.";
416     }
417   }
418   Log() << std::endl;
419   return packet_input_time_ms;
420 }
421 
DecodeBlock()422 int NetEqQualityTest::DecodeBlock() {
423   bool muted;
424   int ret = neteq_->GetAudio(&out_frame_, &muted);
425   RTC_CHECK(!muted);
426 
427   if (ret != NetEq::kOK) {
428     return -1;
429   } else {
430     RTC_DCHECK_EQ(out_frame_.num_channels_, channels_);
431     RTC_DCHECK_EQ(out_frame_.samples_per_channel_,
432                   static_cast<size_t>(kOutputSizeMs * out_sampling_khz_));
433     RTC_CHECK(output_->WriteArray(
434         out_frame_.data(),
435         out_frame_.samples_per_channel_ * out_frame_.num_channels_));
436     return static_cast<int>(out_frame_.samples_per_channel_);
437   }
438 }
439 
Simulate()440 void NetEqQualityTest::Simulate() {
441   int audio_size_samples;
442   bool end_of_input = false;
443   int runtime_ms = absl::GetFlag(FLAGS_runtime_ms) >= 0
444                        ? absl::GetFlag(FLAGS_runtime_ms)
445                        : INT_MAX;
446 
447   while (!end_of_input && decoded_time_ms_ < runtime_ms) {
448     // Preload the buffer if needed.
449     while (decodable_time_ms_ -
450                absl::GetFlag(FLAGS_preload_packets) * block_duration_ms_ <
451            decoded_time_ms_) {
452       if (!in_file_->Read(in_size_samples_ * channels_, &in_data_[0])) {
453         end_of_input = true;
454         ASSERT_TRUE(end_of_input && absl::GetFlag(FLAGS_runtime_ms) < 0);
455         break;
456       }
457       payload_.Clear();
458       payload_size_bytes_ = EncodeBlock(&in_data_[0], in_size_samples_,
459                                         &payload_, max_payload_bytes_);
460       total_payload_size_bytes_ += payload_size_bytes_;
461       decodable_time_ms_ = Transmit() + block_duration_ms_;
462     }
463     audio_size_samples = DecodeBlock();
464     if (audio_size_samples > 0) {
465       decoded_time_ms_ += audio_size_samples / out_sampling_khz_;
466     }
467   }
468   Log() << "Average bit rate was "
469         << 8.0f * total_payload_size_bytes_ / absl::GetFlag(FLAGS_runtime_ms)
470         << " kbps" << std::endl;
471 }
472 
473 }  // namespace test
474 }  // namespace webrtc
475