1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 3-Clause Clear License
5  * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
6  * License was not distributed with this source code in the LICENSE file, you
7  * can obtain it at www.aomedia.org/license/software-license/bsd-3-c-c. If the
8  * Alliance for Open Media Patent License 1.0 was not distributed with this
9  * source code in the PATENTS file, you can obtain it at
10  * www.aomedia.org/license/patent.
11  */
12 #include "iamf/cli/recon_gain_generator.h"
13 
14 #include <cmath>
15 #include <vector>
16 
17 #include "absl/base/no_destructor.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/log/log.h"
20 #include "absl/status/status.h"
21 #include "iamf/cli/channel_label.h"
22 #include "iamf/cli/demixing_module.h"
23 #include "iamf/common/utils/macros.h"
24 #include "iamf/common/utils/map_utils.h"
25 #include "iamf/obu/types.h"
26 
27 namespace iamf_tools {
28 
29 namespace {
30 
31 // Returns the Root Mean Square (RMS) power of input `samples`.
ComputeSignalPower(const std::vector<InternalSampleType> & samples)32 double ComputeSignalPower(const std::vector<InternalSampleType>& samples) {
33   double mean_square = 0.0;
34   const double scale = 1.0 / static_cast<double>(samples.size());
35   for (const auto s : samples) {
36     mean_square += scale * s * s;
37   }
38   return std::sqrt(mean_square);
39 }
40 
41 // Find relevant samples. E.g. Computation of kDemixedLrs7 uses kLs5 and kLss7.
42 // Spec says "relevant mixed channel of the down-mixed audio for CL #i-1." So
43 // Level Mk is the signal power or kLs5. kLss7 is from CL #i and does not
44 // contribute to Level Mk.
FindRelevantMixedSamples(const bool additional_logging,ChannelLabel::Label label,const LabelSamplesMap & label_to_samples,const std::vector<InternalSampleType> ** relevant_mixed_samples)45 absl::Status FindRelevantMixedSamples(
46     const bool additional_logging, ChannelLabel::Label label,
47     const LabelSamplesMap& label_to_samples,
48     const std::vector<InternalSampleType>** relevant_mixed_samples) {
49   using enum ChannelLabel::Label;
50   static const absl::NoDestructor<
51       absl::flat_hash_map<ChannelLabel::Label, ChannelLabel::Label>>
52       kLabelToRelevantMixedLabel({{kDemixedL7, kL5},
53                                   {kDemixedR7, kR5},
54                                   {kDemixedLrs7, kLs5},
55                                   {kDemixedRrs7, kRs5},
56                                   {kDemixedLtb4, kLtf2},
57                                   {kDemixedRtb4, kRtf2},
58                                   {kDemixedL5, kL3},
59                                   {kDemixedR5, kR3},
60                                   {kDemixedLs5, kL3},
61                                   {kDemixedRs5, kR3},
62                                   {kDemixedLtf2, kLtf3},
63                                   {kDemixedRtf2, kRtf3},
64                                   {kDemixedL3, kL2},
65                                   {kDemixedR3, kR2},
66                                   {kDemixedR2, kMono}});
67 
68   ChannelLabel::Label relevant_mixed_label;
69   RETURN_IF_NOT_OK(
70       CopyFromMap(*kLabelToRelevantMixedLabel, label,
71                   "`relevant_mixed_label` for demixed `ChannelLabel::Label`",
72                   relevant_mixed_label));
73 
74   LOG_IF(INFO, additional_logging)
75       << "Relevant mixed samples has label: " << relevant_mixed_label;
76   return DemixingModule::FindSamplesOrDemixedSamples(
77       relevant_mixed_label, label_to_samples, relevant_mixed_samples);
78 }
79 
80 }  // namespace
81 
ComputeReconGain(ChannelLabel::Label label,const LabelSamplesMap & label_to_samples,const LabelSamplesMap & label_to_decoded_samples,const bool additional_logging,double & recon_gain)82 absl::Status ReconGainGenerator::ComputeReconGain(
83     ChannelLabel::Label label, const LabelSamplesMap& label_to_samples,
84     const LabelSamplesMap& label_to_decoded_samples,
85     const bool additional_logging, double& recon_gain) {
86   // Gather information about the original samples.
87   const std::vector<InternalSampleType>* original_samples;
88   RETURN_IF_NOT_OK(DemixingModule::FindSamplesOrDemixedSamples(
89       label, label_to_samples, &original_samples));
90   LOG_IF(INFO, additional_logging)
91       << "[" << label
92       << "] original_samples.size()= " << original_samples->size();
93 
94   // Level Ok in the Spec.
95   const double original_power = ComputeSignalPower(*original_samples);
96 
97   // TODO(b/289064747): Investigate if the recon gain mismatches are resolved
98   //                    after we switched to representing data in [-1, +1].
99   // If 10*log10(level Ok / maxL^2) is less than the first threshold value
100   // (e.g. -80dB), Recon_Gain (k, i) = 0. Where, maxL = 32767 for 16bits.
101   // In this codebase we represent the `InternalSampleType` as a `double` in the
102   // range of [-1, +1], so we use maxL = 1.0 instead.
103   constexpr InternalSampleType kMaxLSquared = 1.0 * 1.0;
104   const double original_power_db = 10 * log10(original_power / kMaxLSquared);
105   LOG_IF(INFO, additional_logging) << "Level OK (dB) " << original_power_db;
106   if (original_power_db < -80) {
107     recon_gain = 0;
108     return absl::OkStatus();
109   }
110 
111   // Gather information about mixed samples.
112   const std::vector<InternalSampleType>* relevant_mixed_samples;
113   RETURN_IF_NOT_OK(FindRelevantMixedSamples(
114       additional_logging, label, label_to_samples, &relevant_mixed_samples));
115   LOG_IF(INFO, additional_logging)
116       << "[" << label
117       << "] relevant_mixed_samples.size()= " << relevant_mixed_samples->size();
118 
119   // Level Mk in the Spec.
120   const double relevant_mixed_power =
121       ComputeSignalPower(*relevant_mixed_samples);
122   const double mixed_power_db = 10 * log10(relevant_mixed_power / kMaxLSquared);
123   LOG_IF(INFO, additional_logging) << "Level MK (dB) " << mixed_power_db;
124 
125   // If 10*log10(level Ok / level Mk ) is less than the second threshold
126   // value (e.g. -6dB), Recon_Gain (k, i) is set to the value which makes
127   // level Ok = Recon_Gain (k, i)^2 x level Dk.
128   double original_mixed_ratio_db =
129       10 * log10(original_power / relevant_mixed_power);
130   LOG_IF(INFO, additional_logging)
131       << "Level Ok (dB) / Level Mk (dB) " << original_mixed_ratio_db;
132 
133   // Otherwise, Recon_Gain (k, i) = 1.
134   if (original_mixed_ratio_db >= -6) {
135     recon_gain = 1;
136     return absl::OkStatus();
137   }
138 
139   // Gather information about the demixed samples.
140   const std::vector<InternalSampleType>* demixed_samples;
141   RETURN_IF_NOT_OK(DemixingModule::FindSamplesOrDemixedSamples(
142       label, label_to_decoded_samples, &demixed_samples));
143   LOG_IF(INFO, additional_logging)
144       << "[" << label
145       << "] demixed_samples.size()= " << demixed_samples->size();
146 
147   // Level Dk in the Spec.
148   const double demixed_power = ComputeSignalPower(*demixed_samples);
149 
150   // Set recon gain to the value implied by the spec.
151   double demixed_power_ratio_db = 10 * log10(demixed_power / mixed_power_db);
152   LOG_IF(INFO, additional_logging)
153       << "Level DK (dB) " << demixed_power_ratio_db;
154   recon_gain = std::sqrt(original_power / demixed_power);
155 
156   return absl::OkStatus();
157 }
158 
159 }  // namespace iamf_tools
160