• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 3-Clause Clear
5  * License and the Alliance for Open Media Patent License 1.0. If the BSD
6  * 3-Clause Clear License was not distributed with this source code in the
7  * LICENSE file, you can obtain it at
8  * www.aomedia.org/license/software-license/bsd-3-c-c. If the Alliance for
9  * Open Media Patent License 1.0 was not distributed with this source code
10  * in the PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11  */
12 #include "iamf/cli/demixing_module.h"
13 
14 #include <algorithm>
15 #include <array>
16 #include <cstdint>
17 #include <iterator>
18 #include <list>
19 #include <optional>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/status/status.h"
25 #include "absl/status/status_matchers.h"
26 #include "absl/status/statusor.h"
27 #include "absl/types/span.h"
28 #include "gmock/gmock.h"
29 #include "gtest/gtest.h"
30 #include "iamf/cli/audio_element_with_data.h"
31 #include "iamf/cli/audio_frame_decoder.h"
32 #include "iamf/cli/audio_frame_with_data.h"
33 #include "iamf/cli/channel_label.h"
34 #include "iamf/cli/proto/user_metadata.pb.h"
35 #include "iamf/cli/proto_conversion/channel_label_utils.h"
36 #include "iamf/cli/proto_conversion/downmixing_reconstruction_util.h"
37 #include "iamf/cli/tests/cli_test_utils.h"
38 #include "iamf/common/utils/numeric_utils.h"
39 #include "iamf/obu/audio_element.h"
40 #include "iamf/obu/audio_frame.h"
41 #include "iamf/obu/codec_config.h"
42 #include "iamf/obu/demixing_info_parameter_data.h"
43 #include "iamf/obu/obu_header.h"
44 #include "iamf/obu/recon_gain_info_parameter_data.h"
45 #include "iamf/obu/types.h"
46 
47 namespace iamf_tools {
48 namespace {
49 
50 using ::absl_testing::IsOk;
51 using ::absl_testing::IsOkAndHolds;
52 using enum ChannelLabel::Label;
53 using ::testing::DoubleEq;
54 using ::testing::DoubleNear;
55 using ::testing::IsEmpty;
56 using ::testing::Not;
57 using ::testing::Pointwise;
58 
59 constexpr DecodedUleb128 kAudioElementId = 137;
60 constexpr std::array<uint8_t, 12> kReconGainValues = {
61     255, 0, 125, 200, 150, 255, 255, 255, 255, 255, 255, 255};
62 const uint32_t kZeroSamplesToTrimAtEnd = 0;
63 const uint32_t kZeroSamplesToTrimAtStart = 0;
64 const int kStartTimestamp = 0;
65 const int kEndTimestamp = 4;
66 const DecodedUleb128 kMonoSubstreamId = 0;
67 const DecodedUleb128 kL2SubstreamId = 1;
68 
69 // TODO(b/305927287): Test computation of linear output gains. Test some cases
70 //                    of erroneous input.
71 
// A label that is present in the map is found directly.
TEST(FindSamplesOrDemixedSamples, FindsMatchingSamples) {
  const std::vector<InternalSampleType> kSamplesToFind = {1, 2, 3};
  const LabelSamplesMap kLabelToSamples = {{kL2, kSamplesToFind}};

  // Start at `nullptr` and use fatal assertions so a failed lookup stops the
  // test instead of dereferencing an indeterminate pointer.
  const std::vector<InternalSampleType>* found_samples = nullptr;
  ASSERT_THAT(DemixingModule::FindSamplesOrDemixedSamples(kL2, kLabelToSamples,
                                                          &found_samples),
              IsOk());
  ASSERT_NE(found_samples, nullptr);
  EXPECT_THAT(*found_samples, Pointwise(DoubleEq(), kSamplesToFind));
}
82 
// Looking up `kR2` falls back to `kDemixedR2` when only the demixed samples
// are present.
TEST(FindSamplesOrDemixedSamples, FindsMatchingDemixedSamples) {
  const std::vector<InternalSampleType> kSamplesToFind = {1, 2, 3};
  const LabelSamplesMap kLabelToSamples = {{kDemixedR2, kSamplesToFind}};

  // Start at `nullptr` and use fatal assertions so a failed lookup stops the
  // test instead of dereferencing an indeterminate pointer.
  const std::vector<InternalSampleType>* found_samples = nullptr;
  ASSERT_THAT(DemixingModule::FindSamplesOrDemixedSamples(kR2, kLabelToSamples,
                                                          &found_samples),
              IsOk());
  ASSERT_NE(found_samples, nullptr);
  EXPECT_THAT(*found_samples, Pointwise(DoubleEq(), kSamplesToFind));
}
93 
// Looking up `kL2` fails when the map holds only `kDemixedR2`.
TEST(FindSamplesOrDemixedSamples, InvalidWhenThereIsNoDemixingLabel) {
  const LabelSamplesMap kOnlyDemixedR2 = {
      {kDemixedR2, std::vector<InternalSampleType>{1, 2, 3}}};

  const std::vector<InternalSampleType>* unused_found_samples = nullptr;
  EXPECT_THAT(DemixingModule::FindSamplesOrDemixedSamples(
                  kL2, kOnlyDemixedR2, &unused_found_samples),
              Not(IsOk()));
}
103 
// When both `kR2` and `kDemixedR2` are present, the regular `kR2` samples win.
TEST(FindSamplesOrDemixedSamples, RegularSamplesTakePrecedence) {
  const std::vector<InternalSampleType> kSamplesToFind = {1, 2, 3};
  const std::vector<InternalSampleType> kDemixedSamplesToIgnore = {4, 5, 6};
  const LabelSamplesMap kLabelToSamples = {
      {kR2, kSamplesToFind}, {kDemixedR2, kDemixedSamplesToIgnore}};

  // Start at `nullptr` and use fatal assertions so a failed lookup stops the
  // test instead of dereferencing an indeterminate pointer.
  const std::vector<InternalSampleType>* found_samples = nullptr;
  ASSERT_THAT(DemixingModule::FindSamplesOrDemixedSamples(kR2, kLabelToSamples,
                                                          &found_samples),
              IsOk());
  ASSERT_NE(found_samples, nullptr);
  EXPECT_THAT(*found_samples, Pointwise(DoubleEq(), kSamplesToFind));
}
115 
// Looking up a label (`kL3`) that appears in neither regular nor demixed form
// in the map is an error.
TEST(FindSamplesOrDemixedSamples, ErrorNoMatchingSamples) {
  const LabelSamplesMap kOnlyL2 = {
      {kL2, std::vector<InternalSampleType>{1, 2, 3}}};

  const std::vector<InternalSampleType>* unused_found_samples = nullptr;
  EXPECT_THAT(DemixingModule::FindSamplesOrDemixedSamples(
                  kL3, kOnlyL2, &unused_found_samples),
              Not(IsOk()));
}
125 
InitAudioElementWithLabelsAndLayers(const SubstreamIdLabelsMap & substream_id_to_labels,const std::vector<ChannelAudioLayerConfig::LoudspeakerLayout> & loudspeaker_layouts,absl::flat_hash_map<DecodedUleb128,AudioElementWithData> & audio_elements)126 void InitAudioElementWithLabelsAndLayers(
127     const SubstreamIdLabelsMap& substream_id_to_labels,
128     const std::vector<ChannelAudioLayerConfig::LoudspeakerLayout>&
129         loudspeaker_layouts,
130     absl::flat_hash_map<DecodedUleb128, AudioElementWithData>& audio_elements) {
131   auto [iter, unused_inserted] = audio_elements.emplace(
132       kAudioElementId,
133       AudioElementWithData{
134           .obu = AudioElementObu(ObuHeader(), kAudioElementId,
135                                  AudioElementObu::kAudioElementChannelBased,
136                                  /*reserved=*/0,
137                                  /*codec_config_id=*/0),
138           .substream_id_to_labels = substream_id_to_labels,
139       });
140   auto& obu = iter->second.obu;
141   ASSERT_THAT(
142       obu.InitializeScalableChannelLayout(loudspeaker_layouts.size(), 0),
143       IsOk());
144   auto& config = std::get<ScalableChannelLayoutConfig>(obu.config_);
145   for (int i = 0; i < loudspeaker_layouts.size(); ++i) {
146     config.channel_audio_layer_configs[i].loudspeaker_layout =
147         loudspeaker_layouts[i];
148   }
149 }
150 
// Creating a module from an empty audio-element-ID-to-config map succeeds.
TEST(CreateForDownMixingAndReconstruction, EmptyConfigMapIsOk) {
  using IdToConfigMap =
      absl::flat_hash_map<DecodedUleb128,
                          DemixingModule::DownmixingAndReconstructionConfig>;
  EXPECT_THAT(
      DemixingModule::CreateForDownMixingAndReconstruction(IdToConfigMap()),
      IsOk());
}
160 
// A two-layer configuration (substream 0 -> mono, substream 1 -> L2) with user
// labels L2/R2 is accepted.
TEST(CreateForDownMixingAndReconstruction, ValidWithTwoLayerStereo) {
  const DemixingModule::DownmixingAndReconstructionConfig kTwoLayerConfig{
      .user_labels = {kL2, kR2},
      .substream_id_to_labels = {{0, {kMono}}, {1, {kL2}}},
      .label_to_output_gain = {}};
  absl::flat_hash_map<DecodedUleb128,
                      DemixingModule::DownmixingAndReconstructionConfig>
      configs_by_element_id = {{kAudioElementId, kTwoLayerConfig}};

  EXPECT_THAT(DemixingModule::CreateForDownMixingAndReconstruction(
                  std::move(configs_by_element_id)),
              IsOk());
}
175 
// A module created for reconstruction only has an empty list of down-mixers.
TEST(InitializeForReconstruction, NeverCreatesDownMixers) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers({{0, {kMono}}, {1, {kL2}}},
                                      {ChannelAudioLayerConfig::kLayoutMono,
                                       ChannelAudioLayerConfig::kLayoutStereo},
                                      audio_elements);
  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  const std::list<Demixer>* down_mixers = nullptr;
  // Fatal checks: `down_mixers` is dereferenced below.
  ASSERT_THAT(demixing_module->GetDownMixers(kAudioElementId, down_mixers),
              IsOk());
  ASSERT_NE(down_mixers, nullptr);
  EXPECT_THAT(*down_mixers, IsEmpty());
}
191 
// A mono + stereo two-layer element yields exactly one demixer.
TEST(CreateForReconstruction, CreatesOneDemixerForTwoLayerStereo) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers({{0, {kMono}}, {1, {kL2}}},
                                      {ChannelAudioLayerConfig::kLayoutMono,
                                       ChannelAudioLayerConfig::kLayoutStereo},
                                      audio_elements);
  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  const std::list<Demixer>* demixer = nullptr;
  // Fatal checks: `demixer` is dereferenced below.
  ASSERT_THAT(demixing_module->GetDemixers(kAudioElementId, demixer), IsOk());
  ASSERT_NE(demixer, nullptr);
  EXPECT_THAT(*demixer, testing::SizeIs(1));
}
206 
// `CreateForReconstruction` rejects an element whose only layer uses the
// reserved loudspeaker layout 14.
TEST(CreateForReconstruction, FailsForReservedLayout14) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{0, {kOmitted}}}, {ChannelAudioLayerConfig::kLayoutReserved14},
      audio_elements);

  EXPECT_THAT(DemixingModule::CreateForReconstruction(audio_elements),
              Not(IsOk()));
}
218 
// A single expanded layer using the LFE expanded layout is accepted.
TEST(CreateForReconstruction, ValidForExpandedLayoutLFE) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{0, {kLFE}}}, {ChannelAudioLayerConfig::kLayoutExpanded},
      audio_elements);

  // Select the LFE expanded layout for the (only) layer.
  auto& scalable_config = std::get<ScalableChannelLayoutConfig>(
      audio_elements.at(kAudioElementId).obu.config_);
  auto& only_layer = scalable_config.channel_audio_layer_configs[0];
  only_layer.expanded_loudspeaker_layout =
      ChannelAudioLayerConfig::kExpandedLayoutLFE;

  EXPECT_THAT(DemixingModule::CreateForReconstruction(audio_elements), IsOk());
}
235 
// A single stereo layer has nothing to demix, so no demixers are created.
TEST(CreateForReconstruction, CreatesNoDemixersForSingleLayerChannelBased) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers({{0, {kL2, kR2}}},
                                      {ChannelAudioLayerConfig::kLayoutStereo},
                                      audio_elements);
  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  const std::list<Demixer>* demixer = nullptr;
  // Fatal checks: `demixer` is dereferenced below.
  ASSERT_THAT(demixing_module->GetDemixers(kAudioElementId, demixer), IsOk());
  ASSERT_NE(demixer, nullptr);
  EXPECT_THAT(*demixer, IsEmpty());
}
249 
// An ambisonics (mono, four substreams) audio element gets no demixers.
TEST(CreateForReconstruction, CreatesNoDemixersForAmbisonics) {
  const DecodedUleb128 kCodecConfigId = 0;
  constexpr std::array<DecodedUleb128, 4> kAmbisonicsSubstreamIds{0, 1, 2, 3};
  absl::flat_hash_map<DecodedUleb128, CodecConfigObu> codec_configs;
  AddLpcmCodecConfigWithIdAndSampleRate(kCodecConfigId, 48000, codec_configs);
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  AddAmbisonicsMonoAudioElementWithSubstreamIds(kAudioElementId, kCodecConfigId,
                                                kAmbisonicsSubstreamIds,
                                                codec_configs, audio_elements);

  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  const std::list<Demixer>* demixer = nullptr;
  // Fatal checks: `demixer` is dereferenced below.
  ASSERT_THAT(demixing_module->GetDemixers(kAudioElementId, demixer), IsOk());
  ASSERT_NE(demixer, nullptr);
  EXPECT_THAT(*demixer, IsEmpty());
}
268 
// `DemixOriginalAudioSamples` returns an error on a module that was created
// via `CreateForReconstruction`.
TEST(DemixOriginalAudioSamples, ReturnsErrorAfterCreateForReconstruction) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{kMonoSubstreamId, {kMono}}, {kL2SubstreamId, {kL2}}},
      {ChannelAudioLayerConfig::kLayoutMono,
       ChannelAudioLayerConfig::kLayoutStereo},
      audio_elements);

  auto reconstruction_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(reconstruction_module, IsOk());

  const auto demix_result =
      reconstruction_module->DemixOriginalAudioSamples({});
  EXPECT_THAT(demix_result, Not(IsOk()));
}
282 
// Demixing decoded mono + L2 frames yields the original labels (`kMono`,
// `kL2`) as well as the demixed `kDemixedR2`.
TEST(DemixDecodedAudioSamples, OutputContainsOriginalAndDemixedSamples) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{kMonoSubstreamId, {kMono}}, {kL2SubstreamId, {kL2}}},
      {ChannelAudioLayerConfig::kLayoutMono,
       ChannelAudioLayerConfig::kLayoutStereo},
      audio_elements);

  // One single-sample frame per substream, mono first, then L2.
  std::list<DecodedAudioFrame> decoded_audio_frames;
  for (const DecodedUleb128 substream_id : {kMonoSubstreamId, kL2SubstreamId}) {
    decoded_audio_frames.push_back(
        DecodedAudioFrame{.substream_id = substream_id,
                          .start_timestamp = kStartTimestamp,
                          .end_timestamp = kEndTimestamp,
                          .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
                          .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
                          .decoded_samples = {{0}},
                          .down_mixing_params = DownMixingParams()});
  }

  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());
  const auto demixed = demixing_module->DemixDecodedAudioSamples(
      decoded_audio_frames);
  ASSERT_THAT(demixed, IsOk());
  ASSERT_TRUE(demixed->contains(kAudioElementId));

  const auto& label_to_samples =
      demixed->at(kAudioElementId).label_to_samples;
  EXPECT_TRUE(label_to_samples.contains(kL2));
  EXPECT_TRUE(label_to_samples.contains(kMono));
  EXPECT_TRUE(label_to_samples.contains(kDemixedR2));
}
320 
// The demixed frame carries the end timestamp and trim counts of its inputs.
TEST(DemixDecodedAudioSamples, OutputEchoesTimingInformation) {
  // These values are not very sensible, but as long as they are consistent
  // between related frames it is OK.
  const DecodedUleb128 kExpectedStartTimestamp = 99;
  const DecodedUleb128 kExpectedEndTimestamp = 123;
  const DecodedUleb128 kExpectedNumSamplesToTrimAtEnd = 999;
  const DecodedUleb128 kExpectedNumSamplesToTrimAtStart = 9999;
  // Uses the file-level `kMonoSubstreamId`/`kL2SubstreamId`. The previous local
  // `kL2SubstreamId` duplicated and shadowed the file-level constant.
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{kMonoSubstreamId, {kMono}}, {kL2SubstreamId, {kL2}}},
      {ChannelAudioLayerConfig::kLayoutMono,
       ChannelAudioLayerConfig::kLayoutStereo},
      audio_elements);
  std::list<DecodedAudioFrame> decoded_audio_frames;
  decoded_audio_frames.push_back(DecodedAudioFrame{
      .substream_id = kMonoSubstreamId,
      .start_timestamp = kExpectedStartTimestamp,
      .end_timestamp = kExpectedEndTimestamp,
      .samples_to_trim_at_end = kExpectedNumSamplesToTrimAtEnd,
      .samples_to_trim_at_start = kExpectedNumSamplesToTrimAtStart,
      .decoded_samples = {{0}},
      .down_mixing_params = DownMixingParams()});
  decoded_audio_frames.push_back(DecodedAudioFrame{
      .substream_id = kL2SubstreamId,
      .start_timestamp = kExpectedStartTimestamp,
      .end_timestamp = kExpectedEndTimestamp,
      .samples_to_trim_at_end = kExpectedNumSamplesToTrimAtEnd,
      .samples_to_trim_at_start = kExpectedNumSamplesToTrimAtStart,
      .decoded_samples = {{0}},
      .down_mixing_params = DownMixingParams()});
  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  const auto id_to_labeled_decoded_frame =
      demixing_module->DemixDecodedAudioSamples(decoded_audio_frames);
  ASSERT_THAT(id_to_labeled_decoded_frame, IsOk());
  ASSERT_TRUE(id_to_labeled_decoded_frame->contains(kAudioElementId));

  // NOTE(review): the start timestamp is set on the inputs but is not compared
  // on the output frame here.
  const auto& labeled_frame = id_to_labeled_decoded_frame->at(kAudioElementId);
  EXPECT_EQ(labeled_frame.end_timestamp, kExpectedEndTimestamp);
  EXPECT_EQ(labeled_frame.samples_to_trim_at_end,
            kExpectedNumSamplesToTrimAtEnd);
  EXPECT_EQ(labeled_frame.samples_to_trim_at_start,
            kExpectedNumSamplesToTrimAtStart);
}
368 
// The decoded samples of the original labels appear unchanged in the output.
TEST(DemixDecodedAudioSamples, OutputEchoesOriginalLabels) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{kMonoSubstreamId, {kMono}}, {kL2SubstreamId, {kL2}}},
      {ChannelAudioLayerConfig::kLayoutMono,
       ChannelAudioLayerConfig::kLayoutStereo},
      audio_elements);
  std::list<DecodedAudioFrame> decoded_audio_frames;
  decoded_audio_frames.push_back(
      DecodedAudioFrame{.substream_id = kMonoSubstreamId,
                        .start_timestamp = kStartTimestamp,
                        .end_timestamp = kEndTimestamp,
                        .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
                        .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
                        .decoded_samples = {{1}, {2}, {3}},
                        .down_mixing_params = DownMixingParams()});
  decoded_audio_frames.push_back(
      DecodedAudioFrame{.substream_id = kL2SubstreamId,
                        .start_timestamp = kStartTimestamp,
                        .end_timestamp = kEndTimestamp,
                        .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
                        .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
                        .decoded_samples = {{9}, {10}, {11}},
                        .down_mixing_params = DownMixingParams()});
  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  const auto id_to_labeled_decoded_frame =
      demixing_module->DemixDecodedAudioSamples(decoded_audio_frames);
  ASSERT_THAT(id_to_labeled_decoded_frame, IsOk());
  ASSERT_TRUE(id_to_labeled_decoded_frame->contains(kAudioElementId));

  // Examine the demixed frame.
  const auto& labeled_frame = id_to_labeled_decoded_frame->at(kAudioElementId);
  constexpr std::array<int32_t, 3> kExpectedMonoSamples = {1, 2, 3};
  constexpr std::array<int32_t, 3> kExpectedL2Samples = {9, 10, 11};
  EXPECT_THAT(
      labeled_frame.label_to_samples.at(kMono),
      Pointwise(InternalSampleMatchesIntegralSample(), kExpectedMonoSamples));
  EXPECT_THAT(
      labeled_frame.label_to_samples.at(kL2),
      Pointwise(InternalSampleMatchesIntegralSample(), kExpectedL2Samples));
}
414 
// R2 is reconstructed from the mono and L2 layers. With M = 750 and L2 = 1000
// the demixed R2 sample is 2 * M - L2 = 500.
TEST(DemixDecodedAudioSamples, OutputHasReconstructedLayers) {
  absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
  InitAudioElementWithLabelsAndLayers(
      {{kMonoSubstreamId, {kMono}}, {kL2SubstreamId, {kL2}}},
      {ChannelAudioLayerConfig::kLayoutMono,
       ChannelAudioLayerConfig::kLayoutStereo},
      audio_elements);
  const auto demixing_module =
      DemixingModule::CreateForReconstruction(audio_elements);
  ASSERT_THAT(demixing_module, IsOk());

  std::list<DecodedAudioFrame> decoded_audio_frames;
  decoded_audio_frames.push_back(
      DecodedAudioFrame{.substream_id = kMonoSubstreamId,
                        .start_timestamp = kStartTimestamp,
                        .end_timestamp = kEndTimestamp,
                        .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
                        .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
                        .decoded_samples = {{750}},
                        .down_mixing_params = DownMixingParams()});
  decoded_audio_frames.push_back(
      DecodedAudioFrame{.substream_id = kL2SubstreamId,
                        .start_timestamp = kStartTimestamp,
                        .end_timestamp = kEndTimestamp,
                        .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
                        .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
                        .decoded_samples = {{1000}},
                        .down_mixing_params = DownMixingParams()});

  const auto demixed =
      demixing_module->DemixDecodedAudioSamples(decoded_audio_frames);
  ASSERT_THAT(demixed, IsOk());
  ASSERT_TRUE(demixed->contains(kAudioElementId));

  // D_R2 = M - (L2 - 6 dB) + 6 dB, i.e. 2 * M - L2.
  const auto& demixed_r2 =
      demixed->at(kAudioElementId).label_to_samples.at(kDemixedR2);
  EXPECT_THAT(demixed_r2,
              Pointwise(InternalSampleMatchesIntegralSample(), {500}));
}
455 
TEST(DemixDecodedAudioSamples,OutputContainsReconGainAndLayerInfo)456 TEST(DemixDecodedAudioSamples, OutputContainsReconGainAndLayerInfo) {
457   absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
458   InitAudioElementWithLabelsAndLayers(
459       {{kMonoSubstreamId, {kMono}}, {kL2SubstreamId, {kL2}}},
460       {ChannelAudioLayerConfig::kLayoutMono,
461        ChannelAudioLayerConfig::kLayoutStereo},
462       audio_elements);
463   std::list<DecodedAudioFrame> decoded_audio_frames;
464   ReconGainInfoParameterData recon_gain_info_parameter_data;
465   recon_gain_info_parameter_data.recon_gain_elements.push_back(ReconGainElement{
466       .recon_gain_flag = DecodedUleb128(1), .recon_gain = kReconGainValues});
467   decoded_audio_frames.push_back(DecodedAudioFrame{
468       .substream_id = kMonoSubstreamId,
469       .start_timestamp = kStartTimestamp,
470       .end_timestamp = kEndTimestamp,
471       .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
472       .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
473       .decoded_samples = {{0}},
474       .down_mixing_params = DownMixingParams(),
475       .recon_gain_info_parameter_data = recon_gain_info_parameter_data,
476       .audio_element_with_data = &audio_elements.at(kAudioElementId)});
477   decoded_audio_frames.push_back(DecodedAudioFrame{
478       .substream_id = kL2SubstreamId,
479       .start_timestamp = kStartTimestamp,
480       .end_timestamp = kEndTimestamp,
481       .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
482       .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
483       .decoded_samples = {{0}},
484       .down_mixing_params = DownMixingParams(),
485       .recon_gain_info_parameter_data = recon_gain_info_parameter_data,
486       .audio_element_with_data = &audio_elements.at(kAudioElementId)});
487   const auto demixing_module =
488       DemixingModule::CreateForReconstruction(audio_elements);
489   ASSERT_THAT(demixing_module, IsOk());
490   const auto id_to_labeled_decoded_frame =
491       demixing_module->DemixDecodedAudioSamples(decoded_audio_frames);
492   ASSERT_THAT(id_to_labeled_decoded_frame, IsOk());
493   ASSERT_TRUE(id_to_labeled_decoded_frame->contains(kAudioElementId));
494 
495   const auto& labeled_frame = id_to_labeled_decoded_frame->at(kAudioElementId);
496   EXPECT_TRUE(labeled_frame.label_to_samples.contains(kL2));
497   EXPECT_TRUE(labeled_frame.label_to_samples.contains(kMono));
498   EXPECT_TRUE(labeled_frame.label_to_samples.contains(kDemixedR2));
499 
500   EXPECT_EQ(
501       labeled_frame.recon_gain_info_parameter_data.recon_gain_elements.size(),
502       1);
503   const auto& recon_gain_element =
504       labeled_frame.recon_gain_info_parameter_data.recon_gain_elements.at(0);
505   ASSERT_TRUE(recon_gain_element.has_value());
506   EXPECT_EQ(recon_gain_element->recon_gain_flag, DecodedUleb128(1));
507   EXPECT_THAT(recon_gain_element->recon_gain,
508               testing::ElementsAreArray(kReconGainValues));
509   EXPECT_EQ(labeled_frame.loudspeaker_layout_per_layer.size(), 2);
510   EXPECT_THAT(labeled_frame.loudspeaker_layout_per_layer,
511               testing::ElementsAre(ChannelAudioLayerConfig::kLayoutMono,
512                                    ChannelAudioLayerConfig::kLayoutStereo));
513 }
514 
515 class DemixingModuleTestBase {
516  public:
DemixingModuleTestBase()517   DemixingModuleTestBase() {
518     audio_frame_metadata_.set_audio_element_id(kAudioElementId);
519   }
520 
CreateDemixingModuleExpectOk()521   void CreateDemixingModuleExpectOk() {
522     iamf_tools_cli_proto::UserMetadata user_metadata;
523     *user_metadata.add_audio_frame_metadata() = audio_frame_metadata_;
524     audio_elements_.emplace(
525         kAudioElementId,
526         AudioElementWithData{
527             .obu = AudioElementObu(ObuHeader(), kAudioElementId,
528                                    AudioElementObu::kAudioElementChannelBased,
529                                    /*reserved=*/0,
530                                    /*codec_config_id=*/0),
531             .substream_id_to_labels = substream_id_to_labels_,
532         });
533     const absl::StatusOr<absl::flat_hash_map<
534         DecodedUleb128, DemixingModule::DownmixingAndReconstructionConfig>>
535         audio_element_id_to_demixing_metadata =
536             CreateAudioElementIdToDemixingMetadata(user_metadata,
537                                                    audio_elements_);
538     ASSERT_THAT(audio_element_id_to_demixing_metadata.status(), IsOk());
539     auto demixing_module = DemixingModule::CreateForDownMixingAndReconstruction(
540         std::move(audio_element_id_to_demixing_metadata.value()));
541     ASSERT_THAT(demixing_module, IsOk());
542     demixing_module_.emplace(*std::move(demixing_module));
543   }
544 
TestCreateDemixingModule(int expected_number_of_down_mixers)545   void TestCreateDemixingModule(int expected_number_of_down_mixers) {
546     CreateDemixingModuleExpectOk();
547     const std::list<Demixer>* down_mixers = nullptr;
548     const std::list<Demixer>* demixers = nullptr;
549 
550     ASSERT_THAT(demixing_module_->GetDownMixers(kAudioElementId, down_mixers),
551                 IsOk());
552     ASSERT_THAT(demixing_module_->GetDemixers(kAudioElementId, demixers),
553                 IsOk());
554     EXPECT_EQ(down_mixers->size(), expected_number_of_down_mixers);
555     EXPECT_EQ(demixers->size(), expected_number_of_down_mixers);
556   }
557 
558  protected:
ConfigureAudioFrameMetadata(absl::Span<const ChannelLabel::Label> labels)559   void ConfigureAudioFrameMetadata(
560       absl::Span<const ChannelLabel::Label> labels) {
561     for (const auto& label : labels) {
562       auto proto_label = ChannelLabelUtils::LabelToProto(label);
563       ASSERT_TRUE(proto_label.ok());
564       audio_frame_metadata_.add_channel_metadatas()->set_channel_label(
565           *proto_label);
566     }
567   }
568 
569   iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata_;
570   absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements_;
571   SubstreamIdLabelsMap substream_id_to_labels_;
572 
573   // Held in `std::optional` for delayed construction.
574   std::optional<DemixingModule> demixing_module_;
575 };
576 
577 class DownMixingModuleTest : public DemixingModuleTestBase,
578                              public ::testing::Test {
579  protected:
TestDownMixing(const DownMixingParams & down_mixing_params,int expected_number_of_down_mixers)580   void TestDownMixing(const DownMixingParams& down_mixing_params,
581                       int expected_number_of_down_mixers) {
582     TestCreateDemixingModule(expected_number_of_down_mixers);
583 
584     EXPECT_THAT(demixing_module_->DownMixSamplesToSubstreams(
585                     kAudioElementId, down_mixing_params,
586                     input_label_to_samples_, substream_id_to_substream_data_),
587                 IsOk());
588 
589     for (const auto& [substream_id, substream_data] :
590          substream_id_to_substream_data_) {
591       // Copy the output queue to a vector for comparison.
592       std::vector<std::vector<int32_t>> output_samples;
593       std::copy(substream_data.samples_obu.begin(),
594                 substream_data.samples_obu.end(),
595                 std::back_inserter(output_samples));
596       EXPECT_EQ(output_samples,
597                 substream_id_to_expected_samples_[substream_id]);
598     }
599   }
600 
ConfigureInputChannel(ChannelLabel::Label label,absl::Span<const int32_t> input_samples)601   void ConfigureInputChannel(ChannelLabel::Label label,
602                              absl::Span<const int32_t> input_samples) {
603     ConfigureAudioFrameMetadata({label});
604 
605     auto [iter, inserted] = input_label_to_samples_.emplace(
606         label, std::vector<InternalSampleType>(input_samples.size(), 0));
607     Int32ToInternalSampleType(input_samples, absl::MakeSpan(iter->second));
608     // This function should not be called with the same label twice.
609     ASSERT_TRUE(inserted);
610   }
611 
ConfigureOutputChannel(const std::list<ChannelLabel::Label> & requested_output_labels,const std::vector<std::vector<int32_t>> & expected_output_smples)612   void ConfigureOutputChannel(
613       const std::list<ChannelLabel::Label>& requested_output_labels,
614       const std::vector<std::vector<int32_t>>& expected_output_smples) {
615     // The substream ID itself does not matter. Generate a unique one.
616     const uint32_t substream_id = substream_id_to_labels_.size();
617 
618     substream_id_to_labels_[substream_id] = requested_output_labels;
619     substream_id_to_substream_data_[substream_id] = {.substream_id =
620                                                          substream_id};
621 
622     substream_id_to_expected_samples_[substream_id] = expected_output_smples;
623   }
624 
625   LabelSamplesMap input_label_to_samples_;
626 
627   absl::flat_hash_map<uint32_t, SubstreamData> substream_id_to_substream_data_;
628 
629   absl::flat_hash_map<uint32_t, std::vector<std::vector<int32_t>>>
630       substream_id_to_expected_samples_;
631 };
632 
TEST_F(DownMixingModuleTest, OneLayerStereoHasNoDownMixers) {
  // Both stereo channels are present, but neither carries any samples.
  for (const ChannelLabel::Label label : {kL2, kR2}) {
    ConfigureInputChannel(label, {});
  }

  // A lone stereo substream forms the one and only layer.
  ConfigureOutputChannel({kL2, kR2}, {{}});

  // A single layer never needs a down-mixer.
  TestCreateDemixingModule(0);
}
641 
TEST_F(DownMixingModuleTest, OneLayer7_1_4HasNoDownMixers) {
  // A single 7.1.4 layer: every channel is present, none carries samples.
  for (const ChannelLabel::Label label :
       {kL7, kR7, kCentre, kLFE, kLss7, kRss7, kLrs7, kRrs7, kLtf4, kRtf4,
        kLtb4, kRtb4}) {
    ConfigureInputChannel(label, {});
  }

  // The 7.1.4 layout is carried by seven substreams.
  ConfigureOutputChannel({kCentre}, {{}});
  ConfigureOutputChannel({kL7, kR7}, {});
  ConfigureOutputChannel({kLss7, kRss7}, {});
  ConfigureOutputChannel({kLrs7, kRrs7}, {});
  ConfigureOutputChannel({kLtf4, kRtf4}, {});
  ConfigureOutputChannel({kLtb4, kRtb4}, {});
  ConfigureOutputChannel({kLFE}, {});

  // With only one layer there is nothing to down-mix.
  TestCreateDemixingModule(0);
}
667 
TEST_F(DownMixingModuleTest, AmbisonicsHasNoDownMixers) {
  // First-order ambisonics: four empty input channels...
  for (const ChannelLabel::Label label : {kA0, kA1, kA2, kA3}) {
    ConfigureInputChannel(label, {});
  }
  // ...each mapped to its own substream.
  for (const ChannelLabel::Label label : {kA0, kA1, kA2, kA3}) {
    ConfigureOutputChannel({label}, {{}});
  }

  // Ambisonics is never down-mixed.
  TestCreateDemixingModule(0);
}
681 
TEST_F(DownMixingModuleTest, OneLayerStereo) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL2, {0, 1, 2, 3}}, {kR2, {100, 101, 102, 103}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Stereo is the only (hence highest) layer, so it passes through untouched,
  // one (L2, R2) pair per tick.
  ConfigureOutputChannel({kL2, kR2}, {{0, 100}, {1, 101}, {2, 102}, {3, 103}});

  // Default parameters; a single layer uses no down-mixers.
  TestDownMixing({}, 0);
}
692 
TEST_F(DownMixingModuleTest, S2ToS1DownMixer) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL2, {0, 100, 500, 1000}}, {kR2, {100, 0, 500, 500}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (stereo): L2 is carried exactly as it was input.
  ConfigureOutputChannel({kL2}, {{0}, {100}, {500}, {1000}});

  // Lowest layer (mono): M = (L2 - 6 dB) + (R2 - 6 dB), so each input
  // contributes roughly half of its amplitude.
  ConfigureOutputChannel({kMono}, {{50}, {50}, {500}, {750}});

  TestDownMixing({}, 1);
}
707 
TEST_F(DownMixingModuleTest, S3ToS2DownMixer) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL3, {0, 100}},
                        {kR3, {0, 100}},
                        {kCentre, {100, 100}},
                        {kLtf3, {99999, 99999}},
                        {kRtf3, {99998, 99998}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (3.1.2): the centre and top-front channels are carried as
  // they were input.
  ConfigureOutputChannel({kCentre}, {{100}, {100}});
  ConfigureOutputChannel({kLtf3, kRtf3}, {{99999, 99998}, {99999, 99998}});

  // Lowest layer (stereo):
  //   L2 = L3 + (C - 3 dB)
  //   R2 = R3 + (C - 3 dB)
  ConfigureOutputChannel({kL2, kR2}, {{70, 70}, {170, 170}});

  TestDownMixing({}, 1);
}
727 
TEST_F(DownMixingModuleTest, S5ToS3ToS2DownMixer) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL5, {100}},    {kR5, {200}},  {kCentre, {1000}},
                        {kLs5, {2000}},  {kRs5, {3000}}, {kLFE, {6}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (5.1): these channels are carried as they were input.
  ConfigureOutputChannel({kCentre}, {{1000}});
  ConfigureOutputChannel({kLs5, kRs5}, {{2000, 3000}});
  ConfigureOutputChannel({kLFE}, {{6}});

  // Lowest layer (stereo), reached through an intermediate 3.x step:
  //   L3 = L5 + Ls5 * delta
  //   L2 = L3 + (C - 3 dB)
  ConfigureOutputChannel({kL2, kR2}, {{2221, 3028}});

  // Two internal down-mixers: 5.x -> 3.x, then 3.x -> stereo.
  TestDownMixing({.delta = .707}, 2);
}
750 
TEST_F(DownMixingModuleTest, S5ToS3ToDownMixer) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL5, {1000}},  {kR5, {2000}},  {kCentre, {3}},
                        {kLs5, {4000}}, {kRs5, {8000}}, {kLtf2, {1000}},
                        {kRtf2, {2000}}, {kLFE, {8}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (5.1.2): the surround pair is carried as it was input.
  ConfigureOutputChannel({kLs5, kRs5}, {{4000, 8000}});

  // Lowest layer (3.1.2).
  //   L3 = L5 + Ls5 * delta
  ConfigureOutputChannel({kL3, kR3}, {{3828, 7656}});
  ConfigureOutputChannel({kCentre}, {{3}});
  //   Ltf3 = Ltf2 + Ls5 * w * delta
  ConfigureOutputChannel({kLtf3, kRtf3}, {{1707, 3414}});
  ConfigureOutputChannel({kLFE}, {{8}});

  // One internal down-mixer handles the surround, another the height.
  TestDownMixing({.delta = .707, .w = 0.25}, 2);
}
777 
TEST_F(DownMixingModuleTest, T4ToT2DownMixer) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL5, {1}},       {kR5, {2}},       {kCentre, {3}},
                        {kLs5, {4}},      {kRs5, {5}},      {kLtf4, {1000}},
                        {kRtf4, {2000}},  {kLtb4, {1000}},  {kRtb4, {2000}},
                        {kLFE, {10}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (5.1.4): the top-back pair is carried as it was input.
  ConfigureOutputChannel({kLtb4, kRtb4}, {{1000, 2000}});

  // Lowest layer (5.1.2).
  ConfigureOutputChannel({kL5, kR5}, {{1, 2}});
  ConfigureOutputChannel({kCentre}, {{3}});
  ConfigureOutputChannel({kLs5, kRs5}, {{4, 5}});
  //   Ltf2 = Ltf4 + Ltb4 * gamma
  ConfigureOutputChannel({kLtf2, kRtf2}, {{1707, 3414}});
  ConfigureOutputChannel({kLFE}, {{10}});

  TestDownMixing({.gamma = .707}, 1);
}
804 
TEST_F(DownMixingModuleTest, S7ToS5DownMixerWithoutT0) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL7, {1}},       {kR7, {2}},      {kCentre, {3}},
                        {kLss7, {1000}},  {kRss7, {2000}}, {kLrs7, {3000}},
                        {kRrs7, {4000}},  {kLFE, {8}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (7.1.0): the rear-surround pair is carried as it was input.
  ConfigureOutputChannel({kLrs7, kRrs7}, {{3000, 4000}});

  // Lowest layer (5.1.0).
  ConfigureOutputChannel({kL5, kR5}, {{1, 2}});
  ConfigureOutputChannel({kCentre}, {{3}});
  //   Ls5 = Lss7 * alpha + Lrs7 * beta
  ConfigureOutputChannel({kLs5, kRs5}, {{3598, 5464}});
  ConfigureOutputChannel({kLFE}, {{8}});

  TestDownMixing({.alpha = 1, .beta = .866}, 1);
}
828 
TEST_F(DownMixingModuleTest, S7ToS5DownMixerWithT2) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL7, {1}},      {kR7, {2}},      {kCentre, {3}},
                        {kLss7, {1000}}, {kRss7, {2000}}, {kLrs7, {3000}},
                        {kRrs7, {4000}}, {kLtf2, {8}},    {kRtf2, {9}},
                        {kLFE, {10}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (7.1.2): the rear-surround pair is carried as it was input.
  ConfigureOutputChannel({kLrs7, kRrs7}, {{3000, 4000}});

  // Lowest layer (5.1.2); the height pair passes through unchanged.
  ConfigureOutputChannel({kL5, kR5}, {{1, 2}});
  ConfigureOutputChannel({kCentre}, {{3}});
  //   Ls5 = Lss7 * alpha + Lrs7 * beta
  ConfigureOutputChannel({kLs5, kRs5}, {{3598, 5464}});
  ConfigureOutputChannel({kLtf2, kRtf2}, {{8, 9}});
  ConfigureOutputChannel({kLFE}, {{10}});

  TestDownMixing({.alpha = 1, .beta = .866}, 1);
}
855 
TEST_F(DownMixingModuleTest, S7ToS5DownMixerWithT4) {
  const std::vector<std::pair<ChannelLabel::Label, std::vector<int32_t>>>
      kInputChannels = {{kL7, {1}},      {kR7, {2}},      {kCentre, {3}},
                        {kLss7, {1000}}, {kRss7, {2000}}, {kLrs7, {3000}},
                        {kRrs7, {4000}}, {kLtf4, {8}},    {kRtf4, {9}},
                        {kLtb4, {10}},   {kRtb4, {11}},   {kLFE, {12}}};
  for (const auto& [label, samples] : kInputChannels) {
    ConfigureInputChannel(label, samples);
  }

  // Highest layer (7.1.4): the rear-surround pair is carried as it was input.
  ConfigureOutputChannel({kLrs7, kRrs7}, {{3000, 4000}});

  // Lowest layer (5.1.4); all four height channels pass through unchanged.
  ConfigureOutputChannel({kL5, kR5}, {{1, 2}});
  ConfigureOutputChannel({kCentre}, {{3}});
  //   Ls5 = Lss7 * alpha + Lrs7 * beta
  ConfigureOutputChannel({kLs5, kRs5}, {{3598, 5464}});
  ConfigureOutputChannel({kLtf4, kRtf4}, {{8, 9}});
  ConfigureOutputChannel({kLtb4, kRtb4}, {{10, 11}});
  ConfigureOutputChannel({kLFE}, {{12}});

  TestDownMixing({.alpha = 1, .beta = .866}, 1);
}
885 
TEST_F(DownMixingModuleTest, SixLayer7_1_4) {
  ConfigureInputChannel(kL7, {1000});
  ConfigureInputChannel(kR7, {2000});
  ConfigureInputChannel(kCentre, {1000});
  ConfigureInputChannel(kLss7, {1000});
  ConfigureInputChannel(kRss7, {2000});
  ConfigureInputChannel(kLrs7, {3000});
  ConfigureInputChannel(kRrs7, {4000});
  ConfigureInputChannel(kLtf4, {1000});
  ConfigureInputChannel(kRtf4, {2000});
  ConfigureInputChannel(kLtb4, {1000});
  ConfigureInputChannel(kRtb4, {2000});
  ConfigureInputChannel(kLFE, {12});

  // There are several ways to split 7.1.4 into six layers. Choose 7.1.4,
  // 7.1.2, 5.1.2, 3.1.2, stereo, and mono. This keeps the height channels for
  // as many steps as possible.

  // Down-mix to 7.1.4 as the sixth layer.
  ConfigureOutputChannel({kLtb4, kRtb4}, {{1000, 2000}});

  // Down-mix to 7.1.2 as the fifth layer.
  ConfigureOutputChannel({kLrs7, kRrs7}, {{3000, 4000}});

  // Down-mix to 5.1.2 as the fourth layer.
  // Ls5 = Lss7 * alpha + Lrs7 * beta.
  ConfigureOutputChannel({kLs5, kRs5}, {{3598, 5464}});

  // Down-mix to 3.1.2 as the third layer.
  ConfigureOutputChannel({kCentre}, {{1000}});
  // Ltf2 = Ltf4 + Ltb4 * gamma.
  // Ltf3 = Ltf2 + Ls5 * w * delta.
  ConfigureOutputChannel({kLtf3, kRtf3}, {{2644, 4914}});
  ConfigureOutputChannel({kLFE}, {{12}});

  // Down-mix to stereo as the second layer.
  // L5 = L7.
  // L3 = L5 + Ls5 * delta.
  // L2 = L3 + (C - 3 dB).
  ConfigureOutputChannel({kL2}, {{4822}});

  // Down-mix to mono as the first layer.
  // R5 = R7.
  // R3 = R5 + Rs5 * delta.
  // R2 = R3 + (C - 3 dB).
  // M = (L2 - 6 dB) + (R2 - 6 dB).
  ConfigureOutputChannel({kMono}, {{6130}});

  TestDownMixing(
      {.alpha = 1, .beta = .866, .gamma = .866, .delta = .866, .w = 0.25}, 6);
}
937 
938 class DemixingModuleTest : public DemixingModuleTestBase,
939                            public ::testing::Test {
940  public:
ConfigureLosslessAudioFrameAndDecodedAudioFrame(const std::list<ChannelLabel::Label> & labels,const std::vector<std::vector<int32_t>> & pcm_samples,DownMixingParams down_mixing_params={ .alpha = 1, .beta = .866, .gamma = .866, .delta = .866, .w = 0.25})941   void ConfigureLosslessAudioFrameAndDecodedAudioFrame(
942       const std::list<ChannelLabel::Label>& labels,
943       const std::vector<std::vector<int32_t>>& pcm_samples,
944       DownMixingParams down_mixing_params = {
945           .alpha = 1, .beta = .866, .gamma = .866, .delta = .866, .w = 0.25}) {
946     // The substream ID itself does not matter. Generate a unique one.
947     const DecodedUleb128 substream_id = substream_id_to_labels_.size();
948     substream_id_to_labels_[substream_id] = labels;
949 
950     // Configure a pair of audio frames and decoded audio frames. They share a
951     // lot of the same information for a lossless codec.
952     audio_frames_.push_back(AudioFrameWithData{
953         .obu = AudioFrameObu(ObuHeader(), substream_id, {}),
954         .start_timestamp = kStartTimestamp,
955         .end_timestamp = kEndTimestamp,
956         .pcm_samples = pcm_samples,
957         .down_mixing_params = down_mixing_params,
958     });
959 
960     decoded_audio_frames_.push_back(
961         DecodedAudioFrame{.substream_id = substream_id,
962                           .start_timestamp = kStartTimestamp,
963                           .end_timestamp = kEndTimestamp,
964                           .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd,
965                           .samples_to_trim_at_start = kZeroSamplesToTrimAtStart,
966                           .decoded_samples = pcm_samples,
967                           .down_mixing_params = down_mixing_params});
968 
969     auto& expected_label_to_samples =
970         expected_id_to_labeled_decoded_frame_[kAudioElementId].label_to_samples;
971     // `raw_samples` is arranged in (time, channel axes). Arrange the samples
972     // associated with each channel by time. The demixing process never changes
973     // data for the input labels.
974     auto labels_iter = labels.begin();
975     for (int channel = 0; channel < labels.size(); ++channel) {
976       auto& samples_for_channel = expected_label_to_samples[*labels_iter];
977 
978       samples_for_channel.reserve(pcm_samples.size());
979       for (auto tick : pcm_samples) {
980         samples_for_channel.push_back(
981             Int32ToNormalizedFloatingPoint<InternalSampleType>(tick[channel]));
982       }
983       labels_iter++;
984     }
985   }
986 
ConfiguredExpectedDemixingChannelFrame(ChannelLabel::Label label,const std::vector<int32_t> & expected_demixed_samples)987   void ConfiguredExpectedDemixingChannelFrame(
988       ChannelLabel::Label label,
989       const std::vector<int32_t>& expected_demixed_samples) {
990     std::vector<InternalSampleType> expected_demixed_samples_as_internal_type;
991     expected_demixed_samples_as_internal_type.reserve(
992         expected_demixed_samples.size());
993     for (int32_t sample : expected_demixed_samples) {
994       expected_demixed_samples_as_internal_type.push_back(
995           Int32ToNormalizedFloatingPoint<InternalSampleType>(sample));
996     }
997 
998     // Configure the expected demixed channels. Typically the input `label`
999     // should have a "D_" prefix.
1000     expected_id_to_labeled_decoded_frame_[kAudioElementId]
1001         .label_to_samples[label] = expected_demixed_samples_as_internal_type;
1002   }
1003 
TestLosslessDemixing(int expected_number_of_down_mixers)1004   void TestLosslessDemixing(int expected_number_of_down_mixers) {
1005     TestCreateDemixingModule(expected_number_of_down_mixers);
1006 
1007     const auto id_to_labeled_decoded_frame =
1008         demixing_module_->DemixDecodedAudioSamples(decoded_audio_frames_);
1009     ASSERT_THAT(id_to_labeled_decoded_frame, IsOk());
1010     ASSERT_TRUE(id_to_labeled_decoded_frame->contains(kAudioElementId));
1011 
1012     // Check that the demixed samples have the correct values.
1013     const auto& actual_label_to_samples =
1014         id_to_labeled_decoded_frame->at(kAudioElementId).label_to_samples;
1015 
1016     const auto& expected_label_to_samples =
1017         expected_id_to_labeled_decoded_frame_[kAudioElementId].label_to_samples;
1018     EXPECT_EQ(actual_label_to_samples.size(), expected_label_to_samples.size());
1019     for (const auto& [label, samples] : actual_label_to_samples) {
1020       // Use `DoubleNear` with a tolerance because floating-point arithmetic
1021       // introduces errors larger than allowed by `DoubleEq`.
1022       constexpr double kErrorTolerance = 1e-14;
1023       EXPECT_THAT(samples, Pointwise(DoubleNear(kErrorTolerance),
1024                                      expected_label_to_samples.at(label)));
1025     }
1026 
1027     // Also, since this is lossless, we expect demixing the original samples
1028     // should give the same result.
1029     const auto id_to_labeled_frame =
1030         demixing_module_->DemixOriginalAudioSamples(audio_frames_);
1031     ASSERT_THAT(id_to_labeled_frame, IsOk());
1032     ASSERT_TRUE(id_to_labeled_frame->contains(kAudioElementId));
1033     EXPECT_EQ(id_to_labeled_frame->at(kAudioElementId).label_to_samples,
1034               actual_label_to_samples);
1035   }
1036 
1037  protected:
1038   std::list<AudioFrameWithData> audio_frames_;
1039   std::list<DecodedAudioFrame> decoded_audio_frames_;
1040 
1041   IdLabeledFrameMap expected_id_to_labeled_decoded_frame_;
1042 };  // namespace
1043 
TEST(DemixingModule, DemixingOriginalAudioSamplesSucceedsWithEmptyInputs) {
  // A module configured with no audio elements...
  const auto module_or =
      DemixingModule::CreateForDownMixingAndReconstruction({});
  ASSERT_THAT(module_or, IsOk());

  // ...demixes an empty list of original frames into an empty result.
  const auto& demixing_module = *module_or;
  EXPECT_THAT(demixing_module.DemixOriginalAudioSamples({}),
              IsOkAndHolds(IsEmpty()));
}
1052 
TEST(DemixingModule, DemixingDecodedAudioSamplesSucceedsWithEmptyInputs) {
  // A module configured with no audio elements...
  const auto module_or =
      DemixingModule::CreateForDownMixingAndReconstruction({});
  ASSERT_THAT(module_or, IsOk());

  // ...demixes an empty list of decoded frames into an empty result.
  const auto& demixing_module = *module_or;
  EXPECT_THAT(demixing_module.DemixDecodedAudioSamples({}),
              IsOkAndHolds(IsEmpty()));
}
1061 
TEST_F(DemixingModuleTest, AmbisonicsHasNoDemixers) {
  ConfigureAudioFrameMetadata({kA0, kA1, kA2, kA3});

  // Each ambisonics channel travels alone in a one-tick substream.
  for (const ChannelLabel::Label label : {kA0, kA1, kA2, kA3}) {
    ConfigureLosslessAudioFrameAndDecodedAudioFrame({label}, {{1}});
  }

  // Nothing is demixed for ambisonics.
  TestLosslessDemixing(0);
}
1072 
TEST_F(DemixingModuleTest, S1ToS2Demixer) {
  // Stereo is the highest layer.
  ConfigureAudioFrameMetadata({kL2, kR2});

  // Layer one (lowest) is mono; layer two adds the left stereo channel.
  const std::vector<std::vector<int32_t>> kMonoSamples = {{750}, {1500}};
  const std::vector<std::vector<int32_t>> kLeftSamples = {{1000}, {2000}};
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kMono}, kMonoSamples);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kL2}, kLeftSamples);

  // The right channel is recovered by demixing:
  //   D_R2 = M - (L2 - 6 dB) + 6 dB.
  ConfiguredExpectedDemixingChannelFrame(kDemixedR2, {500, 1000});

  TestLosslessDemixing(1);
}
1088 
TEST_F(DemixingModuleTest,
       DemixOriginalAudioSamplesReturnsErrorIfAudioFrameIsMissingPcmSamples) {
  ConfigureAudioFrameMetadata({kL2, kR2});
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kMono}, {{750}, {1500}});
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kL2}, {{1000}, {2000}});
  TestCreateDemixingModule(1);

  // Drop the PCM samples of the last configured original audio frame.
  audio_frames_.back().pcm_samples = std::nullopt;

  EXPECT_THAT(demixing_module_->DemixOriginalAudioSamples(audio_frames_),
              Not(IsOk()));
}
1102 
TEST_F(DemixingModuleTest, S2ToS3Demixer) {
  // The highest layer is 3.1.2.
  ConfigureAudioFrameMetadata({kL3, kR3, kCentre, kLtf3, kRtf3});

  // Lowest layer: stereo.
  const std::vector<std::vector<int32_t>> kStereoSamples = {{70, 70},
                                                            {1700, 1700}};
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kL2, kR2}, kStereoSamples);

  // Next layer: 3.1.2 adds the centre and the top-front pair.
  const std::vector<std::vector<int32_t>> kCentreSamples = {{2000}, {1000}};
  const std::vector<std::vector<int32_t>> kTopFrontSamples = {
      {99999, 99998}, {99999, 99998}};
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kCentre}, kCentreSamples);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kLtf3, kRtf3},
                                                  kTopFrontSamples);

  // L3 and R3 are recovered from the lower layers; both inputs are symmetric:
  //   L3 = L2 - (C - 3 dB)
  //   R3 = R2 - (C - 3 dB)
  for (const ChannelLabel::Label label : {kDemixedL3, kDemixedR3}) {
    ConfiguredExpectedDemixingChannelFrame(label, {-1344, 993});
  }

  TestLosslessDemixing(1);
}
1124 
TEST_F(DemixingModuleTest, S3ToS5AndTf2ToT2Demixers) {
  // Any valid layer stacked on 3.1.2 requires both an S3ToS5 and a Tf2ToT2
  // demixer. Here the highest layer is 5.1.2.
  ConfigureAudioFrameMetadata({kL5, kR5, kCentre, kLtf2, kRtf2});

  const DownMixingParams kParams = {.delta = .866, .w = 0.25};
  const auto configure_substream =
      [&](const std::list<ChannelLabel::Label>& labels,
          const std::vector<std::vector<int32_t>>& samples) {
        ConfigureLosslessAudioFrameAndDecodedAudioFrame(labels, samples,
                                                        kParams);
      };

  // Lowest layer: 3.1.2.
  configure_substream({kL3, kR3}, {{18660, 28660}});
  configure_substream({kCentre}, {{100}});
  configure_substream({kLtf3, kRtf3}, {{1000, 2000}});

  // Next layer: 5.1.2.
  configure_substream({kL5, kR5}, {{10000, 20000}});

  // S3ToS5 recovers the surround pair:
  //   Ls5 = (1 / delta) * (L3 - L5)
  //   Rs5 = (1 / delta) * (R3 - R5)
  ConfiguredExpectedDemixingChannelFrame(kDemixedLs5, {10000});
  ConfiguredExpectedDemixingChannelFrame(kDemixedRs5, {10000});

  // Tf2ToT2 recovers the top-front pair:
  //   Ltf2 = Ltf3 - w * (L3 - L5)
  //   Rtf2 = Rtf3 - w * (R3 - R5)
  ConfiguredExpectedDemixingChannelFrame(kDemixedLtf2, {-1165});
  ConfiguredExpectedDemixingChannelFrame(kDemixedRtf2, {-165});

  TestLosslessDemixing(2);
}
1159 
TEST_F(DemixingModuleTest, S5ToS7Demixer) {
  // The highest layer is 7.1.0.
  ConfigureAudioFrameMetadata({kL7, kR7, kCentre, kLss7, kRss7, kLrs7, kRrs7});

  const DownMixingParams kDownMixingParams = {.alpha = 0.866, .beta = .866};

  // 5.1.0 is the lowest layer.
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kL5, kR5}, {{100, 100}},
                                                  kDownMixingParams);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kLs5, kRs5}, {{7794, 7794}},
                                                  kDownMixingParams);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kCentre}, {{100}},
                                                  kDownMixingParams);

  // 7.1.0 as the next layer.
  ConfigureLosslessAudioFrameAndDecodedAudioFrame(
      {kLss7, kRss7}, {{1000, 2000}}, kDownMixingParams);

  // L7/R7 get demixed from the lower layers.
  // L7 = L5.
  // R7 = R5.
  ConfiguredExpectedDemixingChannelFrame(kDemixedL7, {100});
  ConfiguredExpectedDemixingChannelFrame(kDemixedR7, {100});

  // Lrs7/Rrs7 get demixed from the lower layers.
  // Lrs7 = (1 / beta) * (Ls5 - alpha * Lss7).
  // Rrs7 = (1 / beta) * (Rs5 - alpha * Rss7).
  ConfiguredExpectedDemixingChannelFrame(kDemixedLrs7, {8000});
  ConfiguredExpectedDemixingChannelFrame(kDemixedRrs7, {7000});

  TestLosslessDemixing(1);
}
1192 
TEST_F(DemixingModuleTest, T2ToT4Demixer) {
  // The highest layer is 5.1.4.
  ConfigureAudioFrameMetadata({kL5, kR5, kCentre, kLtf4, kRtf4});

  const DownMixingParams kDownMixingParams = {.gamma = .866};

  // 5.1.2 is the lowest layer.
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kL5, kR5}, {{100, 100}},
                                                  kDownMixingParams);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kLs5, kRs5}, {{100, 100}},
                                                  kDownMixingParams);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kCentre}, {{100}},
                                                  kDownMixingParams);
  ConfigureLosslessAudioFrameAndDecodedAudioFrame(
      {kLtf2, kRtf2}, {{8660, 17320}}, kDownMixingParams);

  // 5.1.4 as the next layer.
  ConfigureLosslessAudioFrameAndDecodedAudioFrame({kLtf4, kRtf4}, {{866, 1732}},
                                                  kDownMixingParams);

  // Ltb4/Rtb4 get demixed from the lower layers.
  // Ltb4 = (1 / gamma) * (Ltf2 - Ltf4).
  // Rtb4 = (1 / gamma) * (Rtf2 - Rtf4).
  ConfiguredExpectedDemixingChannelFrame(kDemixedLtb4, {9000});
  ConfiguredExpectedDemixingChannelFrame(kDemixedRtb4, {18000});

  TestLosslessDemixing(1);
}
1221 
1222 }  // namespace
1223 }  // namespace iamf_tools
1224