1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 3-Clause Clear License
5  * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
6  * License was not distributed with this source code in the LICENSE file, you
7  * can obtain it at www.aomedia.org/license/software-license/bsd-3-c-c. If the
8  * Alliance for Open Media Patent License 1.0 was not distributed with this
9  * source code in the PATENTS file, you can obtain it at
10  * www.aomedia.org/license/patent.
11  */
12 
13 #include "iamf/cli/wav_sample_provider.h"
14 
15 #include <cstddef>
16 #include <cstdint>
17 #include <filesystem>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/log/log.h"
24 #include "absl/status/status.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/string_view.h"
27 #include "iamf/cli/audio_element_with_data.h"
28 #include "iamf/cli/demixing_module.h"
29 #include "iamf/cli/proto/audio_frame.pb.h"
30 #include "iamf/cli/proto_conversion/channel_label_utils.h"
31 #include "iamf/cli/wav_reader.h"
32 #include "iamf/common/utils/macros.h"
33 #include "iamf/common/utils/numeric_utils.h"
34 #include "iamf/common/utils/validation_utils.h"
35 #include "iamf/obu/codec_config.h"
36 #include "iamf/obu/types.h"
37 #include "src/google/protobuf/repeated_ptr_field.h"
38 
39 namespace iamf_tools {
40 
41 namespace {
42 
FillChannelIdsFromDeprecatedChannelIds(const iamf_tools_cli_proto::AudioFrameObuMetadata & audio_frame_metadata,std::vector<uint32_t> & channel_ids)43 absl::Status FillChannelIdsFromDeprecatedChannelIds(
44     const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
45     std::vector<uint32_t>& channel_ids) {
46   if (audio_frame_metadata.channel_ids_size() !=
47       audio_frame_metadata.channel_labels_size()) {
48     return absl::InvalidArgumentError(
49         absl::StrCat("#channel IDs and #channel labels differ: (",
50                      audio_frame_metadata.channel_ids_size(), " vs ",
51                      audio_frame_metadata.channel_labels_size(), ")"));
52   }
53 
54   // Precompute the channel IDs for the audio element.
55   channel_ids.reserve(audio_frame_metadata.channel_ids_size());
56   for (const uint32_t channel_id : audio_frame_metadata.channel_ids()) {
57     channel_ids.push_back(channel_id);
58   }
59 
60   return absl::OkStatus();
61 }
62 
FillChannelIdsFromChannelMetadatas(const iamf_tools_cli_proto::AudioFrameObuMetadata & audio_frame_metadata,std::vector<uint32_t> & channel_ids)63 absl::Status FillChannelIdsFromChannelMetadatas(
64     const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
65     std::vector<uint32_t>& channel_ids) {
66   if (!audio_frame_metadata.channel_ids().empty()) {
67     return absl::InvalidArgumentError(
68         "Please fully upgrade to `channel_metadatas`. Leave `channel_ids` "
69         "empty");
70   }
71   // Collect the channel IDs from the `channel_metadatas` field.
72   channel_ids.reserve(audio_frame_metadata.channel_metadatas().size());
73   for (const auto& channel_metadata :
74        audio_frame_metadata.channel_metadatas()) {
75     channel_ids.push_back(channel_metadata.channel_id());
76   }
77   return absl::OkStatus();
78 }
79 
FillChannelIdsAndLabels(const iamf_tools_cli_proto::AudioFrameObuMetadata & audio_frame_metadata,std::vector<uint32_t> & channel_ids,std::vector<ChannelLabel::Label> & channel_labels)80 absl::Status FillChannelIdsAndLabels(
81     const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
82     std::vector<uint32_t>& channel_ids,
83     std::vector<ChannelLabel::Label>& channel_labels) {
84   // Collect the channel IDs.
85   if (!audio_frame_metadata.channel_metadatas().empty()) {
86     RETURN_IF_NOT_OK(
87         FillChannelIdsFromChannelMetadatas(audio_frame_metadata, channel_ids));
88   } else {
89     RETURN_IF_NOT_OK(FillChannelIdsFromDeprecatedChannelIds(
90         audio_frame_metadata, channel_ids));
91   }
92   if (!ValidateUnique(channel_ids.begin(), channel_ids.end(), "channel ids")
93            .ok()) {
94     // OK. The user is claiming some channel IDs are shared between labels.
95     // This is strange, but permitted.
96     LOG(WARNING) << "Usually channel labels should be unique. Did you use "
97                     "the same channel ID for different channels?";
98   }
99 
100   // Precompute the internal `ChannelLabel::Label`s.
101   RETURN_IF_NOT_OK(ChannelLabelUtils::SelectConvertAndFillLabels(
102       audio_frame_metadata, channel_labels));
103 
104   return absl::OkStatus();
105 }
ValidateWavReaderIsConsistentWithData(absl::string_view wav_filename_for_debugging,const WavReader & wav_reader,const CodecConfigObu & codec_config,const std::vector<uint32_t> & channel_ids)106 absl::Status ValidateWavReaderIsConsistentWithData(
107     absl::string_view wav_filename_for_debugging, const WavReader& wav_reader,
108     const CodecConfigObu& codec_config,
109     const std::vector<uint32_t>& channel_ids) {
110   const std::string pretty_print_wav_filename =
111       absl::StrCat("WAV (", wav_filename_for_debugging, ")");
112   const int encoder_input_pcm_bit_depth =
113       static_cast<int>(codec_config.GetBitDepthToMeasureLoudness());
114   if (wav_reader.bit_depth() > encoder_input_pcm_bit_depth) {
115     return absl::InvalidArgumentError(absl::StrCat(
116         "Refusing to lower bit-depth of ", pretty_print_wav_filename,
117         " with bit_depth= ", wav_reader.bit_depth(),
118         " to bit_depth=", encoder_input_pcm_bit_depth));
119   }
120 
121   const uint32_t encoder_input_sample_rate = codec_config.GetInputSampleRate();
122   if (wav_reader.sample_rate_hz() != encoder_input_sample_rate) {
123     return absl::InvalidArgumentError(absl::StrCat(
124         pretty_print_wav_filename, "has a sample rate of ",
125         wav_reader.sample_rate_hz(), " Hz. Expected a sample rate of ",
126         encoder_input_sample_rate,
127         " Hz based on the Codec Config OBU. Consider using a third party "
128         "resampler on the WAV file, or picking Codec Config OBU settings to "
129         "match the WAV file before trying again."));
130   }
131 
132   const uint32_t decoder_output_sample_rate =
133       codec_config.GetOutputSampleRate();
134   if (encoder_input_sample_rate != decoder_output_sample_rate) {
135     return absl::InvalidArgumentError(absl::StrCat(
136         "Input and output sample rates differ: (", encoder_input_sample_rate,
137         " vs ", decoder_output_sample_rate, ")"));
138   }
139 
140   // To prevent indexing out of bounds after the `WavSampleProvider` is
141   // created, we ensure all user-specified channel IDs are in range of the
142   // number of channels in the input file.
143   for (const uint32_t channel_id : channel_ids) {
144     if (channel_id >= wav_reader.num_channels()) {
145       return absl::InvalidArgumentError(
146           absl::StrCat(pretty_print_wav_filename,
147                        " has num_channels= ", wav_reader.num_channels(),
148                        ". channel_id= ", channel_id, " is out of bounds."));
149     }
150   }
151 
152   return absl::OkStatus();
153 }
154 
155 // Fills in `channel_ids`, `labels`, and creates a `WavReader` from the input
156 // metadata and other input data.
InitializeForAudioElement(uint32_t audio_element_id,const iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata,const std::string & wav_filename,const CodecConfigObu & codec_config,std::vector<uint32_t> & channel_ids,std::vector<ChannelLabel::Label> & labels,absl::flat_hash_map<DecodedUleb128,WavReader> & audio_element_id_to_wav_reader)157 absl::Status InitializeForAudioElement(
158     uint32_t audio_element_id,
159     const iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata,
160     const std::string& wav_filename, const CodecConfigObu& codec_config,
161     std::vector<uint32_t>& channel_ids,
162     std::vector<ChannelLabel::Label>& labels,
163     absl::flat_hash_map<DecodedUleb128, WavReader>&
164         audio_element_id_to_wav_reader) {
165   RETURN_IF_NOT_OK(
166       FillChannelIdsAndLabels(audio_frame_metadata, channel_ids, labels));
167 
168   auto wav_reader = WavReader::CreateFromFile(
169       wav_filename, static_cast<size_t>(codec_config.GetNumSamplesPerFrame()));
170   if (!wav_reader.ok()) {
171     return wav_reader.status();
172   }
173   RETURN_IF_NOT_OK(ValidateWavReaderIsConsistentWithData(
174       wav_filename, *wav_reader, codec_config, channel_ids));
175 
176   audio_element_id_to_wav_reader.emplace(audio_element_id,
177                                          std::move(*wav_reader));
178 
179   return absl::OkStatus();
180 }
181 
182 }  // namespace
183 
Create(const::google::protobuf::RepeatedPtrField<iamf_tools_cli_proto::AudioFrameObuMetadata> & audio_frame_metadata,absl::string_view input_wav_directory,const absl::flat_hash_map<DecodedUleb128,AudioElementWithData> & audio_elements)184 absl::StatusOr<WavSampleProvider> WavSampleProvider::Create(
185     const ::google::protobuf::RepeatedPtrField<
186         iamf_tools_cli_proto::AudioFrameObuMetadata>& audio_frame_metadata,
187     absl::string_view input_wav_directory,
188     const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
189         audio_elements) {
190   // Precompute, validate, and cache data for each audio element.
191   absl::flat_hash_map<DecodedUleb128, WavReader> wav_readers;
192   absl::flat_hash_map<DecodedUleb128, std::vector<uint32_t>>
193       audio_element_id_to_channel_ids;
194   absl::flat_hash_map<DecodedUleb128, std::vector<ChannelLabel::Label>>
195       audio_element_id_to_labels;
196 
197   const std::filesystem::path input_wav_directory_path(input_wav_directory);
198   for (const auto& audio_frame_obu_metadata : audio_frame_metadata) {
199     const uint32_t audio_element_id =
200         audio_frame_obu_metadata.audio_element_id();
201     const auto& wav_filename =
202         input_wav_directory_path /
203         std::filesystem::path(audio_frame_obu_metadata.wav_filename());
204 
205     // Retrieve the Codec Config OBU for the audio element.
206     auto audio_element_iter = audio_elements.find(audio_element_id);
207     if (audio_element_iter == audio_elements.end()) {
208       return absl::InvalidArgumentError(
209           absl::StrCat("No Audio Element found for ID= ",
210                        audio_frame_obu_metadata.audio_element_id()));
211     }
212     const CodecConfigObu* codec_config =
213         audio_element_iter->second.codec_config;
214     if (codec_config == nullptr) {
215       return absl::InvalidArgumentError(
216           absl::StrCat("No Codec Config found for Audio Element ID= ",
217                        audio_frame_obu_metadata.audio_element_id()));
218     }
219 
220     auto [channel_ids_iter, inserted] = audio_element_id_to_channel_ids.emplace(
221         audio_element_id, std::vector<uint32_t>());
222     if (!inserted) {
223       return absl::InvalidArgumentError(
224           absl::StrCat("List of AudioFrameObuMetadatahas contains duplicate "
225                        "Audio Element ID= ",
226                        audio_element_id));
227     }
228     // Internals add to the maps in parallel; if one had an empty slot, then
229     // the others will have an empty slot.
230 
231     RETURN_IF_NOT_OK(InitializeForAudioElement(
232         audio_element_id, audio_frame_obu_metadata, wav_filename.string(),
233         *codec_config, channel_ids_iter->second,
234         audio_element_id_to_labels[audio_element_id], wav_readers));
235   }
236   return WavSampleProvider(std::move(wav_readers),
237                            std::move(audio_element_id_to_channel_ids),
238                            std::move(audio_element_id_to_labels));
239 }
240 
ReadFrames(const DecodedUleb128 audio_element_id,LabelSamplesMap & labeled_samples,bool & finished_reading)241 absl::Status WavSampleProvider::ReadFrames(
242     const DecodedUleb128 audio_element_id, LabelSamplesMap& labeled_samples,
243     bool& finished_reading) {
244   auto wav_reader_iter = wav_readers_.find(audio_element_id);
245   if (wav_reader_iter == wav_readers_.end()) {
246     return absl::InvalidArgumentError(absl::StrCat(
247         "No WAV reader found for Audio Element ID= ", audio_element_id));
248   }
249   auto& wav_reader = wav_reader_iter->second;
250   const size_t samples_read = wav_reader.ReadFrame();
251   LOG_FIRST_N(INFO, 1) << samples_read << " samples read";
252 
253   // Note if the WAV reader is found for the Audio Element ID, then it's
254   // guaranteed to have the other corresponding metadata (otherwise the
255   // `Create()` would have failed).
256   const size_t num_time_ticks = samples_read / wav_reader.num_channels();
257   const auto& channel_ids =
258       audio_element_id_to_channel_ids_.at(audio_element_id);
259   const auto& channel_labels = audio_element_id_to_labels_.at(audio_element_id);
260   labeled_samples.clear();
261   for (int c = 0; c < channel_labels.size(); ++c) {
262     auto& samples = labeled_samples[channel_labels[c]];
263     samples.resize(num_time_ticks);
264     for (int t = 0; t < num_time_ticks; ++t) {
265       samples[t] = Int32ToNormalizedFloatingPoint<InternalSampleType>(
266           wav_reader.buffers_[t][channel_ids[c]]);
267     }
268   }
269   finished_reading = (wav_reader.remaining_samples() == 0);
270 
271   return absl::OkStatus();
272 }
273 
WavSampleProvider(absl::flat_hash_map<DecodedUleb128,WavReader> && wav_readers,absl::flat_hash_map<DecodedUleb128,std::vector<uint32_t>> && audio_element_id_to_channel_ids,absl::flat_hash_map<DecodedUleb128,std::vector<ChannelLabel::Label>> && audio_element_id_to_labels)274 WavSampleProvider::WavSampleProvider(
275     absl::flat_hash_map<DecodedUleb128, WavReader>&& wav_readers,
276     absl::flat_hash_map<DecodedUleb128, std::vector<uint32_t>>&&
277         audio_element_id_to_channel_ids,
278     absl::flat_hash_map<DecodedUleb128, std::vector<ChannelLabel::Label>>&&
279         audio_element_id_to_labels)
280     : wav_readers_(std::move(wav_readers)),
281       audio_element_id_to_channel_ids_(
282           std::move(audio_element_id_to_channel_ids)),
283       audio_element_id_to_labels_(std::move(audio_element_id_to_labels)) {};
284 
285 }  // namespace iamf_tools
286