1 /*
2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 3-Clause Clear License
5 * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
6 * License was not distributed with this source code in the LICENSE file, you
7 * can obtain it at www.aomedia.org/license/software-license/bsd-3-c-c. If the
8 * Alliance for Open Media Patent License 1.0 was not distributed with this
9 * source code in the PATENTS file, you can obtain it at
10 * www.aomedia.org/license/patent.
11 */
12
13 #include "iamf/cli/wav_sample_provider.h"
14
15 #include <cstddef>
16 #include <cstdint>
17 #include <filesystem>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/log/log.h"
24 #include "absl/status/status.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/string_view.h"
27 #include "iamf/cli/audio_element_with_data.h"
28 #include "iamf/cli/demixing_module.h"
29 #include "iamf/cli/proto/audio_frame.pb.h"
30 #include "iamf/cli/proto_conversion/channel_label_utils.h"
31 #include "iamf/cli/wav_reader.h"
32 #include "iamf/common/utils/macros.h"
33 #include "iamf/common/utils/numeric_utils.h"
34 #include "iamf/common/utils/validation_utils.h"
35 #include "iamf/obu/codec_config.h"
36 #include "iamf/obu/types.h"
37 #include "src/google/protobuf/repeated_ptr_field.h"
38
39 namespace iamf_tools {
40
41 namespace {
42
FillChannelIdsFromDeprecatedChannelIds(const iamf_tools_cli_proto::AudioFrameObuMetadata & audio_frame_metadata,std::vector<uint32_t> & channel_ids)43 absl::Status FillChannelIdsFromDeprecatedChannelIds(
44 const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
45 std::vector<uint32_t>& channel_ids) {
46 if (audio_frame_metadata.channel_ids_size() !=
47 audio_frame_metadata.channel_labels_size()) {
48 return absl::InvalidArgumentError(
49 absl::StrCat("#channel IDs and #channel labels differ: (",
50 audio_frame_metadata.channel_ids_size(), " vs ",
51 audio_frame_metadata.channel_labels_size(), ")"));
52 }
53
54 // Precompute the channel IDs for the audio element.
55 channel_ids.reserve(audio_frame_metadata.channel_ids_size());
56 for (const uint32_t channel_id : audio_frame_metadata.channel_ids()) {
57 channel_ids.push_back(channel_id);
58 }
59
60 return absl::OkStatus();
61 }
62
FillChannelIdsFromChannelMetadatas(const iamf_tools_cli_proto::AudioFrameObuMetadata & audio_frame_metadata,std::vector<uint32_t> & channel_ids)63 absl::Status FillChannelIdsFromChannelMetadatas(
64 const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
65 std::vector<uint32_t>& channel_ids) {
66 if (!audio_frame_metadata.channel_ids().empty()) {
67 return absl::InvalidArgumentError(
68 "Please fully upgrade to `channel_metadatas`. Leave `channel_ids` "
69 "empty");
70 }
71 // Collect the channel IDs from the `channel_metadatas` field.
72 channel_ids.reserve(audio_frame_metadata.channel_metadatas().size());
73 for (const auto& channel_metadata :
74 audio_frame_metadata.channel_metadatas()) {
75 channel_ids.push_back(channel_metadata.channel_id());
76 }
77 return absl::OkStatus();
78 }
79
FillChannelIdsAndLabels(const iamf_tools_cli_proto::AudioFrameObuMetadata & audio_frame_metadata,std::vector<uint32_t> & channel_ids,std::vector<ChannelLabel::Label> & channel_labels)80 absl::Status FillChannelIdsAndLabels(
81 const iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata,
82 std::vector<uint32_t>& channel_ids,
83 std::vector<ChannelLabel::Label>& channel_labels) {
84 // Collect the channel IDs.
85 if (!audio_frame_metadata.channel_metadatas().empty()) {
86 RETURN_IF_NOT_OK(
87 FillChannelIdsFromChannelMetadatas(audio_frame_metadata, channel_ids));
88 } else {
89 RETURN_IF_NOT_OK(FillChannelIdsFromDeprecatedChannelIds(
90 audio_frame_metadata, channel_ids));
91 }
92 if (!ValidateUnique(channel_ids.begin(), channel_ids.end(), "channel ids")
93 .ok()) {
94 // OK. The user is claiming some channel IDs are shared between labels.
95 // This is strange, but permitted.
96 LOG(WARNING) << "Usually channel labels should be unique. Did you use "
97 "the same channel ID for different channels?";
98 }
99
100 // Precompute the internal `ChannelLabel::Label`s.
101 RETURN_IF_NOT_OK(ChannelLabelUtils::SelectConvertAndFillLabels(
102 audio_frame_metadata, channel_labels));
103
104 return absl::OkStatus();
105 }
ValidateWavReaderIsConsistentWithData(absl::string_view wav_filename_for_debugging,const WavReader & wav_reader,const CodecConfigObu & codec_config,const std::vector<uint32_t> & channel_ids)106 absl::Status ValidateWavReaderIsConsistentWithData(
107 absl::string_view wav_filename_for_debugging, const WavReader& wav_reader,
108 const CodecConfigObu& codec_config,
109 const std::vector<uint32_t>& channel_ids) {
110 const std::string pretty_print_wav_filename =
111 absl::StrCat("WAV (", wav_filename_for_debugging, ")");
112 const int encoder_input_pcm_bit_depth =
113 static_cast<int>(codec_config.GetBitDepthToMeasureLoudness());
114 if (wav_reader.bit_depth() > encoder_input_pcm_bit_depth) {
115 return absl::InvalidArgumentError(absl::StrCat(
116 "Refusing to lower bit-depth of ", pretty_print_wav_filename,
117 " with bit_depth= ", wav_reader.bit_depth(),
118 " to bit_depth=", encoder_input_pcm_bit_depth));
119 }
120
121 const uint32_t encoder_input_sample_rate = codec_config.GetInputSampleRate();
122 if (wav_reader.sample_rate_hz() != encoder_input_sample_rate) {
123 return absl::InvalidArgumentError(absl::StrCat(
124 pretty_print_wav_filename, "has a sample rate of ",
125 wav_reader.sample_rate_hz(), " Hz. Expected a sample rate of ",
126 encoder_input_sample_rate,
127 " Hz based on the Codec Config OBU. Consider using a third party "
128 "resampler on the WAV file, or picking Codec Config OBU settings to "
129 "match the WAV file before trying again."));
130 }
131
132 const uint32_t decoder_output_sample_rate =
133 codec_config.GetOutputSampleRate();
134 if (encoder_input_sample_rate != decoder_output_sample_rate) {
135 return absl::InvalidArgumentError(absl::StrCat(
136 "Input and output sample rates differ: (", encoder_input_sample_rate,
137 " vs ", decoder_output_sample_rate, ")"));
138 }
139
140 // To prevent indexing out of bounds after the `WavSampleProvider` is
141 // created, we ensure all user-specified channel IDs are in range of the
142 // number of channels in the input file.
143 for (const uint32_t channel_id : channel_ids) {
144 if (channel_id >= wav_reader.num_channels()) {
145 return absl::InvalidArgumentError(
146 absl::StrCat(pretty_print_wav_filename,
147 " has num_channels= ", wav_reader.num_channels(),
148 ". channel_id= ", channel_id, " is out of bounds."));
149 }
150 }
151
152 return absl::OkStatus();
153 }
154
155 // Fills in `channel_ids`, `labels`, and creates a `WavReader` from the input
156 // metadata and other input data.
InitializeForAudioElement(uint32_t audio_element_id,const iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata,const std::string & wav_filename,const CodecConfigObu & codec_config,std::vector<uint32_t> & channel_ids,std::vector<ChannelLabel::Label> & labels,absl::flat_hash_map<DecodedUleb128,WavReader> & audio_element_id_to_wav_reader)157 absl::Status InitializeForAudioElement(
158 uint32_t audio_element_id,
159 const iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata,
160 const std::string& wav_filename, const CodecConfigObu& codec_config,
161 std::vector<uint32_t>& channel_ids,
162 std::vector<ChannelLabel::Label>& labels,
163 absl::flat_hash_map<DecodedUleb128, WavReader>&
164 audio_element_id_to_wav_reader) {
165 RETURN_IF_NOT_OK(
166 FillChannelIdsAndLabels(audio_frame_metadata, channel_ids, labels));
167
168 auto wav_reader = WavReader::CreateFromFile(
169 wav_filename, static_cast<size_t>(codec_config.GetNumSamplesPerFrame()));
170 if (!wav_reader.ok()) {
171 return wav_reader.status();
172 }
173 RETURN_IF_NOT_OK(ValidateWavReaderIsConsistentWithData(
174 wav_filename, *wav_reader, codec_config, channel_ids));
175
176 audio_element_id_to_wav_reader.emplace(audio_element_id,
177 std::move(*wav_reader));
178
179 return absl::OkStatus();
180 }
181
182 } // namespace
183
Create(const::google::protobuf::RepeatedPtrField<iamf_tools_cli_proto::AudioFrameObuMetadata> & audio_frame_metadata,absl::string_view input_wav_directory,const absl::flat_hash_map<DecodedUleb128,AudioElementWithData> & audio_elements)184 absl::StatusOr<WavSampleProvider> WavSampleProvider::Create(
185 const ::google::protobuf::RepeatedPtrField<
186 iamf_tools_cli_proto::AudioFrameObuMetadata>& audio_frame_metadata,
187 absl::string_view input_wav_directory,
188 const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
189 audio_elements) {
190 // Precompute, validate, and cache data for each audio element.
191 absl::flat_hash_map<DecodedUleb128, WavReader> wav_readers;
192 absl::flat_hash_map<DecodedUleb128, std::vector<uint32_t>>
193 audio_element_id_to_channel_ids;
194 absl::flat_hash_map<DecodedUleb128, std::vector<ChannelLabel::Label>>
195 audio_element_id_to_labels;
196
197 const std::filesystem::path input_wav_directory_path(input_wav_directory);
198 for (const auto& audio_frame_obu_metadata : audio_frame_metadata) {
199 const uint32_t audio_element_id =
200 audio_frame_obu_metadata.audio_element_id();
201 const auto& wav_filename =
202 input_wav_directory_path /
203 std::filesystem::path(audio_frame_obu_metadata.wav_filename());
204
205 // Retrieve the Codec Config OBU for the audio element.
206 auto audio_element_iter = audio_elements.find(audio_element_id);
207 if (audio_element_iter == audio_elements.end()) {
208 return absl::InvalidArgumentError(
209 absl::StrCat("No Audio Element found for ID= ",
210 audio_frame_obu_metadata.audio_element_id()));
211 }
212 const CodecConfigObu* codec_config =
213 audio_element_iter->second.codec_config;
214 if (codec_config == nullptr) {
215 return absl::InvalidArgumentError(
216 absl::StrCat("No Codec Config found for Audio Element ID= ",
217 audio_frame_obu_metadata.audio_element_id()));
218 }
219
220 auto [channel_ids_iter, inserted] = audio_element_id_to_channel_ids.emplace(
221 audio_element_id, std::vector<uint32_t>());
222 if (!inserted) {
223 return absl::InvalidArgumentError(
224 absl::StrCat("List of AudioFrameObuMetadatahas contains duplicate "
225 "Audio Element ID= ",
226 audio_element_id));
227 }
228 // Internals add to the maps in parallel; if one had an empty slot, then
229 // the others will have an empty slot.
230
231 RETURN_IF_NOT_OK(InitializeForAudioElement(
232 audio_element_id, audio_frame_obu_metadata, wav_filename.string(),
233 *codec_config, channel_ids_iter->second,
234 audio_element_id_to_labels[audio_element_id], wav_readers));
235 }
236 return WavSampleProvider(std::move(wav_readers),
237 std::move(audio_element_id_to_channel_ids),
238 std::move(audio_element_id_to_labels));
239 }
240
ReadFrames(const DecodedUleb128 audio_element_id,LabelSamplesMap & labeled_samples,bool & finished_reading)241 absl::Status WavSampleProvider::ReadFrames(
242 const DecodedUleb128 audio_element_id, LabelSamplesMap& labeled_samples,
243 bool& finished_reading) {
244 auto wav_reader_iter = wav_readers_.find(audio_element_id);
245 if (wav_reader_iter == wav_readers_.end()) {
246 return absl::InvalidArgumentError(absl::StrCat(
247 "No WAV reader found for Audio Element ID= ", audio_element_id));
248 }
249 auto& wav_reader = wav_reader_iter->second;
250 const size_t samples_read = wav_reader.ReadFrame();
251 LOG_FIRST_N(INFO, 1) << samples_read << " samples read";
252
253 // Note if the WAV reader is found for the Audio Element ID, then it's
254 // guaranteed to have the other corresponding metadata (otherwise the
255 // `Create()` would have failed).
256 const size_t num_time_ticks = samples_read / wav_reader.num_channels();
257 const auto& channel_ids =
258 audio_element_id_to_channel_ids_.at(audio_element_id);
259 const auto& channel_labels = audio_element_id_to_labels_.at(audio_element_id);
260 labeled_samples.clear();
261 for (int c = 0; c < channel_labels.size(); ++c) {
262 auto& samples = labeled_samples[channel_labels[c]];
263 samples.resize(num_time_ticks);
264 for (int t = 0; t < num_time_ticks; ++t) {
265 samples[t] = Int32ToNormalizedFloatingPoint<InternalSampleType>(
266 wav_reader.buffers_[t][channel_ids[c]]);
267 }
268 }
269 finished_reading = (wav_reader.remaining_samples() == 0);
270
271 return absl::OkStatus();
272 }
273
WavSampleProvider(absl::flat_hash_map<DecodedUleb128,WavReader> && wav_readers,absl::flat_hash_map<DecodedUleb128,std::vector<uint32_t>> && audio_element_id_to_channel_ids,absl::flat_hash_map<DecodedUleb128,std::vector<ChannelLabel::Label>> && audio_element_id_to_labels)274 WavSampleProvider::WavSampleProvider(
275 absl::flat_hash_map<DecodedUleb128, WavReader>&& wav_readers,
276 absl::flat_hash_map<DecodedUleb128, std::vector<uint32_t>>&&
277 audio_element_id_to_channel_ids,
278 absl::flat_hash_map<DecodedUleb128, std::vector<ChannelLabel::Label>>&&
279 audio_element_id_to_labels)
280 : wav_readers_(std::move(wav_readers)),
281 audio_element_id_to_channel_ids_(
282 std::move(audio_element_id_to_channel_ids)),
283 audio_element_id_to_labels_(std::move(audio_element_id_to_labels)) {};
284
285 } // namespace iamf_tools
286