1 /*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <math.h>
12 #include <stdio.h>
13 #include "webrtc/base/checks.h"
14 #include "webrtc/modules/audio_coding/neteq/tools/neteq_quality_test.h"
15 #include "webrtc/modules/audio_coding/neteq/tools/output_audio_file.h"
16 #include "webrtc/modules/audio_coding/neteq/tools/output_wav_file.h"
17 #include "webrtc/modules/audio_coding/neteq/tools/resample_input_audio_file.h"
18 #include "webrtc/test/testsupport/fileutils.h"
19
20 using std::string;
21
22 namespace webrtc {
23 namespace test {
24
// RTP payload type used both when registering the decoder with NetEq and when
// generating RTP headers for the simulated packets.
const uint8_t kPayloadType = 95;

// Duration, in milliseconds, of each block of audio pulled from NetEq.
const int kOutputSizeMs = 10;

// Seed handed to srand() in SetUp() so that every derived test sees the same
// sequence of random loss drawings.
const int kInitSeed = 0x12345678;

// Time granularity, in milliseconds, of a single packet-loss drawing.
const int kPacketLossTimeUnitMs = 10;
29
// Common validator for file names.
// Tries to open |value| -- for binary writing when |write| is true, for binary
// reading otherwise -- and reports whether the open succeeded. The handle is
// closed again before returning. Note that probing in write mode uses "wb",
// so the file is created (or an existing file truncated) as a side effect.
static bool ValidateFilename(const string& value, bool write) {
  const char* const mode = write ? "wb" : "rb";
  FILE* const probe = fopen(value.c_str(), mode);
  if (probe != nullptr) {
    fclose(probe);
    return true;
  }
  return false;
}
38
39 // Define switch for input file name.
ValidateInFilename(const char * flagname,const string & value)40 static bool ValidateInFilename(const char* flagname, const string& value) {
41 if (!ValidateFilename(value, false)) {
42 printf("Invalid input filename.");
43 return false;
44 }
45 return true;
46 }
47
48 DEFINE_string(
49 in_filename,
50 ResourcePath("audio_coding/speech_mono_16kHz", "pcm"),
51 "Filename for input audio (specify sample rate with --input_sample_rate ,"
52 "and channels with --channels).");
53
54 static const bool in_filename_dummy =
55 RegisterFlagValidator(&FLAGS_in_filename, &ValidateInFilename);
56
// Define switch for sample rate.
// Flag validator: accepts exactly 8000, 16000, 32000 and 48000 (Hz); any other
// |value| is rejected with a message on stdout. |flagname| is not used.
static bool ValidateSampleRate(const char* flagname, int32_t value) {
  switch (value) {
    case 8000:
    case 16000:
    case 32000:
    case 48000:
      return true;
    default:
      printf("Invalid sample rate should be 8000, 16000, 32000 or 48000 Hz.");
      return false;
  }
}
64
65 DEFINE_int32(input_sample_rate, 16000, "Sample rate of input file in Hz.");
66
67 static const bool sample_rate_dummy =
68 RegisterFlagValidator(&FLAGS_input_sample_rate, &ValidateSampleRate);
69
// Define switch for channels.
// Flag validator: only a single channel is accepted; every other |value| is
// rejected with a message on stdout. |flagname| is not used.
static bool ValidateChannels(const char* flagname, int32_t value) {
  const bool is_mono = (value == 1);
  if (!is_mono) {
    printf("Invalid number of channels, current support only 1.");
  }
  return is_mono;
}
77
78 DEFINE_int32(channels, 1, "Number of channels in input audio.");
79
80 static const bool channels_dummy =
81 RegisterFlagValidator(&FLAGS_channels, &ValidateChannels);
82
83 // Define switch for output file name.
ValidateOutFilename(const char * flagname,const string & value)84 static bool ValidateOutFilename(const char* flagname, const string& value) {
85 if (!ValidateFilename(value, true)) {
86 printf("Invalid output filename.");
87 return false;
88 }
89 return true;
90 }
91
92 DEFINE_string(out_filename,
93 OutputPath() + "neteq_quality_test_out.pcm",
94 "Name of output audio file.");
95
96 static const bool out_filename_dummy =
97 RegisterFlagValidator(&FLAGS_out_filename, &ValidateOutFilename);
98
// Define switch for packet loss rate.
// Flag validator: |value| must lie in the closed range [0, 100]; anything
// outside is rejected with a message on stdout.
static bool ValidatePacketLossRate(const char* /* flag_name */, int32_t value) {
  if (value < 0 || value > 100) {
    printf("Invalid packet loss percentile, should be between 0 and 100.");
    return false;
  }
  return true;
}
106
// Define switch for runtime.
// Flag validator: |value| must be strictly positive; zero and negative values
// are rejected with a message on stdout. |flagname| is not used.
static bool ValidateRuntime(const char* flagname, int32_t value) {
  if (value <= 0) {
    printf("Invalid runtime, should be greater than 0.");
    return false;
  }
  return true;
}
114
115 DEFINE_int32(runtime_ms, 10000, "Simulated runtime (milliseconds).");
116
117 static const bool runtime_dummy =
118 RegisterFlagValidator(&FLAGS_runtime_ms, &ValidateRuntime);
119
120 DEFINE_int32(packet_loss_rate, 10, "Percentile of packet loss.");
121
122 static const bool packet_loss_rate_dummy =
123 RegisterFlagValidator(&FLAGS_packet_loss_rate, &ValidatePacketLossRate);
124
// Define switch for random loss mode.
// Flag validator: the mode must be 0, 1 or 2; other values are rejected with a
// message on stdout.
static bool ValidateRandomLossMode(const char* /* flag_name */, int32_t value) {
  if (value < 0 || value > 2) {
    printf("Invalid random packet loss mode, should be between 0 and 2.");
    return false;
  }
  return true;
}
132
133 DEFINE_int32(random_loss_mode, 1,
134 "Random loss mode: 0--no loss, 1--uniform loss, 2--Gilbert Elliot loss.");
135 static const bool random_loss_mode_dummy =
136 RegisterFlagValidator(&FLAGS_random_loss_mode, &ValidateRandomLossMode);
137
138 // Define switch for burst length.
ValidateBurstLength(const char *,int32_t value)139 static bool ValidateBurstLength(const char* /* flag_name */, int32_t value) {
140 if (value >= kPacketLossTimeUnitMs)
141 return true;
142 printf("Invalid burst length, should be greater than %d ms.",
143 kPacketLossTimeUnitMs);
144 return false;
145 }
146
147 DEFINE_int32(burst_length, 30,
148 "Burst length in milliseconds, only valid for Gilbert Elliot loss.");
149
150 static const bool burst_length_dummy =
151 RegisterFlagValidator(&FLAGS_burst_length, &ValidateBurstLength);
152
// Define switch for drift factor.
// Flag validator: |value| must be strictly greater than -0.1; anything else
// is rejected with a message on stdout.
static bool ValidateDriftFactor(const char* /* flag_name */, double value) {
  const bool acceptable = value > -0.1;
  if (!acceptable) {
    printf("Invalid drift factor, should be greater than -0.1.");
  }
  return acceptable;
}
160
161 DEFINE_double(drift_factor, 0.0, "Time drift factor.");
162
163 static const bool drift_factor_dummy =
164 RegisterFlagValidator(&FLAGS_drift_factor, &ValidateDriftFactor);
165
// ProbTrans00Solver() computes the transition probability from the no-loss
// state to itself in a modified Gilbert Elliot packet loss model, such that the
// target packet loss rate |loss_rate| is met when a packet survives only if all
// |units| drawings made during the packet's duration come out as no-loss.
// |prob_trans_10| is the loss -> no-loss transition probability.
//
// For |units| == 1 a closed-form expression is returned. Otherwise the value
// is the root of
//   f(x) = x ^ (units - 1) + a * x + b,
//   a = (1 - loss_rate) / prob_trans_10,
//   b = (loss_rate - 1) * (1 + 1 / prob_trans_10),
// which is monotonic with opposite signs at 0.0 and 1.0, so it has a single
// root in that interval. The root is found with Newton's method,
//   x(k+1) = x(k) - f(x(k)) / f'(x(k)),  f'(x) = (units - 1) x ^ (units - 2) + a,
// started at x = 0, with every iterate clamped to [0, 1]. Iteration stops once
// |f| drops below the precision or after a fixed number of steps.
static double ProbTrans00Solver(int units, double loss_rate,
                                double prob_trans_10) {
  if (units == 1)
    return prob_trans_10 / (1.0f - loss_rate) - prob_trans_10;

  const double kPrecision = 0.001f;
  const int kIterations = 100;
  const double slope = (1.0f - loss_rate) / prob_trans_10;
  const double offset = (loss_rate - 1.0f) * (1.0f + 1.0f / prob_trans_10);

  double root = 0.0f;        // Starting point.
  double residual = offset;  // Value of f at the starting point.
  for (int step = 0; step < kIterations; ++step) {
    // Written as a negation so that a NaN residual also ends the iteration.
    const bool out_of_tolerance =
        (residual >= kPrecision || residual <= -kPrecision);
    if (!out_of_tolerance)
      break;

    const double derivative = (units - 1.0f) * pow(root, units - 2) + slope;
    root -= residual / derivative;

    // Keep the iterate inside the valid probability range.
    if (root > 1.0f) {
      root = 1.0f;
    } else if (root < 0.0f) {
      root = 0.0f;
    }
    residual = pow(root, units - 1) + slope * root + offset;
  }
  return root;
}
207
NetEqQualityTest(int block_duration_ms,int in_sampling_khz,int out_sampling_khz,NetEqDecoder decoder_type)208 NetEqQualityTest::NetEqQualityTest(int block_duration_ms,
209 int in_sampling_khz,
210 int out_sampling_khz,
211 NetEqDecoder decoder_type)
212 : decoder_type_(decoder_type),
213 channels_(static_cast<size_t>(FLAGS_channels)),
214 decoded_time_ms_(0),
215 decodable_time_ms_(0),
216 drift_factor_(FLAGS_drift_factor),
217 packet_loss_rate_(FLAGS_packet_loss_rate),
218 block_duration_ms_(block_duration_ms),
219 in_sampling_khz_(in_sampling_khz),
220 out_sampling_khz_(out_sampling_khz),
221 in_size_samples_(
222 static_cast<size_t>(in_sampling_khz_ * block_duration_ms_)),
223 out_size_samples_(static_cast<size_t>(out_sampling_khz_ * kOutputSizeMs)),
224 payload_size_bytes_(0),
225 max_payload_bytes_(0),
226 in_file_(new ResampleInputAudioFile(FLAGS_in_filename,
227 FLAGS_input_sample_rate,
228 in_sampling_khz * 1000)),
229 rtp_generator_(
230 new RtpGenerator(in_sampling_khz_, 0, 0, decodable_time_ms_)),
231 total_payload_size_bytes_(0) {
232 const std::string out_filename = FLAGS_out_filename;
233 const std::string log_filename = out_filename + ".log";
234 log_file_.open(log_filename.c_str(), std::ofstream::out);
235 RTC_CHECK(log_file_.is_open());
236
237 if (out_filename.size() >= 4 &&
238 out_filename.substr(out_filename.size() - 4) == ".wav") {
239 // Open a wav file.
240 output_.reset(
241 new webrtc::test::OutputWavFile(out_filename, 1000 * out_sampling_khz));
242 } else {
243 // Open a pcm file.
244 output_.reset(new webrtc::test::OutputAudioFile(out_filename));
245 }
246
247 NetEq::Config config;
248 config.sample_rate_hz = out_sampling_khz_ * 1000;
249 neteq_.reset(NetEq::Create(config));
250 max_payload_bytes_ = in_size_samples_ * channels_ * sizeof(int16_t);
251 in_data_.reset(new int16_t[in_size_samples_ * channels_]);
252 payload_.reset(new uint8_t[max_payload_bytes_]);
253 out_data_.reset(new int16_t[out_size_samples_ * channels_]);
254 }
255
~NetEqQualityTest()256 NetEqQualityTest::~NetEqQualityTest() {
257 log_file_.close();
258 }
259
Lost()260 bool NoLoss::Lost() {
261 return false;
262 }
263
UniformLoss(double loss_rate)264 UniformLoss::UniformLoss(double loss_rate)
265 : loss_rate_(loss_rate) {
266 }
267
Lost()268 bool UniformLoss::Lost() {
269 int drop_this = rand();
270 return (drop_this < loss_rate_ * RAND_MAX);
271 }
272
GilbertElliotLoss(double prob_trans_11,double prob_trans_01)273 GilbertElliotLoss::GilbertElliotLoss(double prob_trans_11, double prob_trans_01)
274 : prob_trans_11_(prob_trans_11),
275 prob_trans_01_(prob_trans_01),
276 lost_last_(false),
277 uniform_loss_model_(new UniformLoss(0)) {
278 }
279
Lost()280 bool GilbertElliotLoss::Lost() {
281 // Simulate bursty channel (Gilbert model).
282 // (1st order) Markov chain model with memory of the previous/last
283 // packet state (lost or received).
284 if (lost_last_) {
285 // Previous packet was not received.
286 uniform_loss_model_->set_loss_rate(prob_trans_11_);
287 return lost_last_ = uniform_loss_model_->Lost();
288 } else {
289 uniform_loss_model_->set_loss_rate(prob_trans_01_);
290 return lost_last_ = uniform_loss_model_->Lost();
291 }
292 }
293
SetUp()294 void NetEqQualityTest::SetUp() {
295 ASSERT_EQ(0,
296 neteq_->RegisterPayloadType(decoder_type_, "noname", kPayloadType));
297 rtp_generator_->set_drift_factor(drift_factor_);
298
299 int units = block_duration_ms_ / kPacketLossTimeUnitMs;
300 switch (FLAGS_random_loss_mode) {
301 case 1: {
302 // |unit_loss_rate| is the packet loss rate for each unit time interval
303 // (kPacketLossTimeUnitMs). Since a packet loss event is generated if any
304 // of |block_duration_ms_ / kPacketLossTimeUnitMs| unit time intervals of
305 // a full packet duration is drawn with a loss, |unit_loss_rate| fulfills
306 // (1 - unit_loss_rate) ^ (block_duration_ms_ / kPacketLossTimeUnitMs) ==
307 // 1 - packet_loss_rate.
308 double unit_loss_rate = (1.0f - pow(1.0f - 0.01f * packet_loss_rate_,
309 1.0f / units));
310 loss_model_.reset(new UniformLoss(unit_loss_rate));
311 break;
312 }
313 case 2: {
314 // |FLAGS_burst_length| should be integer times of kPacketLossTimeUnitMs.
315 ASSERT_EQ(0, FLAGS_burst_length % kPacketLossTimeUnitMs);
316
317 // We do not allow 100 percent packet loss in Gilbert Elliot model, which
318 // makes no sense.
319 ASSERT_GT(100, packet_loss_rate_);
320
321 // To guarantee the overall packet loss rate, transition probabilities
322 // need to satisfy:
323 // pi_0 * (1 - prob_trans_01_) ^ units +
324 // pi_1 * prob_trans_10_ ^ (units - 1) == 1 - loss_rate
325 // pi_0 = prob_trans_10 / (prob_trans_10 + prob_trans_01_)
326 // is the stationary state probability of no-loss
327 // pi_1 = prob_trans_01_ / (prob_trans_10 + prob_trans_01_)
328 // is the stationary state probability of loss
329 // After a derivation prob_trans_00 should satisfy:
330 // prob_trans_00 ^ (units - 1) = (loss_rate - 1) / prob_trans_10 *
331 // prob_trans_00 + (1 - loss_rate) * (1 + 1 / prob_trans_10).
332 double loss_rate = 0.01f * packet_loss_rate_;
333 double prob_trans_10 = 1.0f * kPacketLossTimeUnitMs / FLAGS_burst_length;
334 double prob_trans_00 = ProbTrans00Solver(units, loss_rate, prob_trans_10);
335 loss_model_.reset(new GilbertElliotLoss(1.0f - prob_trans_10,
336 1.0f - prob_trans_00));
337 break;
338 }
339 default: {
340 loss_model_.reset(new NoLoss);
341 break;
342 }
343 }
344
345 // Make sure that the packet loss profile is same for all derived tests.
346 srand(kInitSeed);
347 }
348
Log()349 std::ofstream& NetEqQualityTest::Log() {
350 return log_file_;
351 }
352
PacketLost()353 bool NetEqQualityTest::PacketLost() {
354 int cycles = block_duration_ms_ / kPacketLossTimeUnitMs;
355
356 // The loop is to make sure that codecs with different block lengths share the
357 // same packet loss profile.
358 bool lost = false;
359 for (int idx = 0; idx < cycles; idx ++) {
360 if (loss_model_->Lost()) {
361 // The packet will be lost if any of the drawings indicates a loss, but
362 // the loop has to go on to make sure that codecs with different block
363 // lengths keep the same pace.
364 lost = true;
365 }
366 }
367 return lost;
368 }
369
Transmit()370 int NetEqQualityTest::Transmit() {
371 int packet_input_time_ms =
372 rtp_generator_->GetRtpHeader(kPayloadType, in_size_samples_,
373 &rtp_header_);
374 Log() << "Packet of size "
375 << payload_size_bytes_
376 << " bytes, for frame at "
377 << packet_input_time_ms
378 << " ms ";
379 if (payload_size_bytes_ > 0) {
380 if (!PacketLost()) {
381 int ret = neteq_->InsertPacket(
382 rtp_header_,
383 rtc::ArrayView<const uint8_t>(payload_.get(), payload_size_bytes_),
384 packet_input_time_ms * in_sampling_khz_);
385 if (ret != NetEq::kOK)
386 return -1;
387 Log() << "was sent.";
388 } else {
389 Log() << "was lost.";
390 }
391 }
392 Log() << std::endl;
393 return packet_input_time_ms;
394 }
395
DecodeBlock()396 int NetEqQualityTest::DecodeBlock() {
397 size_t channels;
398 size_t samples;
399 int ret = neteq_->GetAudio(out_size_samples_ * channels_, &out_data_[0],
400 &samples, &channels, NULL);
401
402 if (ret != NetEq::kOK) {
403 return -1;
404 } else {
405 assert(channels == channels_);
406 assert(samples == static_cast<size_t>(kOutputSizeMs * out_sampling_khz_));
407 RTC_CHECK(output_->WriteArray(out_data_.get(), samples * channels));
408 return static_cast<int>(samples);
409 }
410 }
411
Simulate()412 void NetEqQualityTest::Simulate() {
413 int audio_size_samples;
414
415 while (decoded_time_ms_ < FLAGS_runtime_ms) {
416 // Assume 10 packets in packets buffer.
417 while (decodable_time_ms_ - 10 * block_duration_ms_ < decoded_time_ms_) {
418 ASSERT_TRUE(in_file_->Read(in_size_samples_ * channels_, &in_data_[0]));
419 payload_size_bytes_ = EncodeBlock(&in_data_[0],
420 in_size_samples_, &payload_[0],
421 max_payload_bytes_);
422 total_payload_size_bytes_ += payload_size_bytes_;
423 decodable_time_ms_ = Transmit() + block_duration_ms_;
424 }
425 audio_size_samples = DecodeBlock();
426 if (audio_size_samples > 0) {
427 decoded_time_ms_ += audio_size_samples / out_sampling_khz_;
428 }
429 }
430 Log() << "Average bit rate was "
431 << 8.0f * total_payload_size_bytes_ / FLAGS_runtime_ms
432 << " kbps"
433 << std::endl;
434 }
435
436 } // namespace test
437 } // namespace webrtc
438