1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 3-Clause Clear License
5 * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
6 * License was not distributed with this source code in the LICENSE file, you
7 * can obtain it at www.aomedia.org/license/software-license/bsd-3-c-c. If the
8 * Alliance for Open Media Patent License 1.0 was not distributed with this
9 * source code in the PATENTS file, you can obtain it at
10 * www.aomedia.org/license/patent.
11 */
12 #include "iamf/cli/proto_conversion/proto_to_obu/parameter_block_generator.h"
13
14 #include <cstdint>
15 #include <list>
16 #include <memory>
17 #include <optional>
18 #include <utility>
19 #include <variant>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/log/check.h"
25 #include "absl/log/log.h"
26 #include "absl/status/status.h"
27 #include "absl/strings/str_cat.h"
28 #include "iamf/cli/audio_element_with_data.h"
29 #include "iamf/cli/channel_label.h"
30 #include "iamf/cli/cli_util.h"
31 #include "iamf/cli/demixing_module.h"
32 #include "iamf/cli/global_timing_module.h"
33 #include "iamf/cli/parameter_block_with_data.h"
34 #include "iamf/cli/proto/parameter_block.pb.h"
35 #include "iamf/cli/proto/parameter_data.pb.h"
36 #include "iamf/cli/proto_conversion/proto_utils.h"
37 #include "iamf/cli/recon_gain_generator.h"
38 #include "iamf/common/utils/macros.h"
39 #include "iamf/common/utils/numeric_utils.h"
40 #include "iamf/common/utils/validation_utils.h"
41 #include "iamf/obu/demixing_info_parameter_data.h"
42 #include "iamf/obu/demixing_param_definition.h"
43 #include "iamf/obu/mix_gain_parameter_data.h"
44 #include "iamf/obu/param_definition_variant.h"
45 #include "iamf/obu/param_definitions.h"
46 #include "iamf/obu/parameter_block.h"
47 #include "iamf/obu/recon_gain_info_parameter_data.h"
48 #include "iamf/obu/types.h"
49
50 namespace iamf_tools {
51
52 namespace {
53
54 std::optional<ParamDefinition::ParameterDefinitionType>
GetParameterDefinitionType(const ParamDefinitionVariant & parameter_definition_variant)55 GetParameterDefinitionType(
56 const ParamDefinitionVariant& parameter_definition_variant) {
57 return std::visit(
58 [](const auto& param_definition) { return param_definition.GetType(); },
59 parameter_definition_variant);
60 }
61
GetParameterDefinitionMode(const ParamDefinitionVariant & parameter_definition_variant)62 uint8_t GetParameterDefinitionMode(
63 const ParamDefinitionVariant& parameter_definition_variant) {
64 return std::visit(
65 [](const auto& param_definition) {
66 return param_definition.param_definition_mode_;
67 },
68 parameter_definition_variant);
69 }
70
GenerateMixGainSubblock(const iamf_tools_cli_proto::MixGainParameterData & metadata_mix_gain_parameter_data,const MixGainParamDefinition * param_definition,std::unique_ptr<ParameterData> & parameter_data)71 absl::Status GenerateMixGainSubblock(
72 const iamf_tools_cli_proto::MixGainParameterData&
73 metadata_mix_gain_parameter_data,
74 const MixGainParamDefinition* param_definition,
75 std::unique_ptr<ParameterData>& parameter_data) {
76 parameter_data = param_definition->CreateParameterData();
77 auto* mix_gain_parameter_data =
78 static_cast<MixGainParameterData*>(parameter_data.get());
79 switch (metadata_mix_gain_parameter_data.animation_type()) {
80 using enum iamf_tools_cli_proto::AnimationType;
81 case ANIMATE_STEP: {
82 const auto& metadata_animation =
83 metadata_mix_gain_parameter_data.param_data().step();
84 mix_gain_parameter_data->animation_type =
85 MixGainParameterData::kAnimateStep;
86 AnimationStepInt16 obu_animation;
87 RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
88 "AnimationStepInt16.start_point_value",
89 metadata_animation.start_point_value(),
90 obu_animation.start_point_value));
91 mix_gain_parameter_data->param_data = obu_animation;
92 break;
93 }
94 case ANIMATE_LINEAR: {
95 const auto& metadata_animation =
96 metadata_mix_gain_parameter_data.param_data().linear();
97 mix_gain_parameter_data->animation_type =
98 MixGainParameterData::kAnimateLinear;
99
100 AnimationLinearInt16 obu_animation;
101 RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
102 "AnimationLinearInt16.start_point_value",
103 metadata_animation.start_point_value(),
104 obu_animation.start_point_value));
105 RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
106 "AnimationLinearInt16.end_point_value",
107 metadata_animation.end_point_value(), obu_animation.end_point_value));
108 mix_gain_parameter_data->param_data = obu_animation;
109 break;
110 }
111 case ANIMATE_BEZIER: {
112 const auto& metadata_animation =
113 metadata_mix_gain_parameter_data.param_data().bezier();
114 mix_gain_parameter_data->animation_type =
115 MixGainParameterData::kAnimateBezier;
116 AnimationBezierInt16 obu_animation;
117 RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
118 "AnimationBezierInt16.start_point_value",
119 metadata_animation.start_point_value(),
120 obu_animation.start_point_value));
121 RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
122 "AnimationBezierInt16.end_point_value",
123 metadata_animation.end_point_value(), obu_animation.end_point_value));
124 RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
125 "AnimationBezierInt16.control_point_value",
126 metadata_animation.control_point_value(),
127 obu_animation.control_point_value));
128 RETURN_IF_NOT_OK(StaticCastIfInRange<uint32_t, uint8_t>(
129 "AnimationBezierInt16.control_point_relative_time",
130 metadata_animation.control_point_relative_time(),
131 obu_animation.control_point_relative_time));
132 mix_gain_parameter_data->param_data = obu_animation;
133 break;
134 }
135 default:
136 return absl::InvalidArgumentError(
137 absl::StrCat("Unrecognized animation type= ",
138 metadata_mix_gain_parameter_data.animation_type()));
139 }
140
141 return absl::OkStatus();
142 }
143
FindDemixedChannels(const ChannelNumbers & accumulated_channels,const ChannelNumbers & layer_channels,std::list<ChannelLabel::Label> * const demixed_channel_labels)144 absl::Status FindDemixedChannels(
145 const ChannelNumbers& accumulated_channels,
146 const ChannelNumbers& layer_channels,
147 std::list<ChannelLabel::Label>* const demixed_channel_labels) {
148 using enum ChannelLabel::Label;
149 for (int surround = accumulated_channels.surround + 1;
150 surround <= layer_channels.surround; surround++) {
151 switch (surround) {
152 case 2:
153 // Previous layer is Mono, this layer is Stereo.
154 if (accumulated_channels.surround == 1) {
155 demixed_channel_labels->push_back(kDemixedR2);
156 }
157 break;
158 case 3:
159 demixed_channel_labels->push_back(kDemixedL3);
160 demixed_channel_labels->push_back(kDemixedR3);
161 break;
162 case 5:
163 demixed_channel_labels->push_back(kDemixedLs5);
164 demixed_channel_labels->push_back(kDemixedRs5);
165 break;
166 case 7:
167 demixed_channel_labels->push_back(kDemixedL7);
168 demixed_channel_labels->push_back(kDemixedR7);
169 demixed_channel_labels->push_back(kDemixedLrs7);
170 demixed_channel_labels->push_back(kDemixedRrs7);
171 break;
172 default:
173 if (surround > 7) {
174 return absl::InvalidArgumentError(absl::StrCat(
175 "Unsupported number of surround channels: ", surround));
176 }
177 break;
178 }
179 }
180
181 if (accumulated_channels.height == 2) {
182 if (layer_channels.height == 4) {
183 demixed_channel_labels->push_back(kDemixedLtb4);
184 demixed_channel_labels->push_back(kDemixedRtb4);
185 } else if (layer_channels.height == 2 &&
186 accumulated_channels.surround == 3 &&
187 layer_channels.surround > 3) {
188 demixed_channel_labels->push_back(kDemixedLtf2);
189 demixed_channel_labels->push_back(kDemixedRtf2);
190 }
191 }
192
193 return absl::OkStatus();
194 }
195
// Converts per-label recon gains into the fixed-size representation used by
// the OBU.
//
// On success `computed_recon_gains` holds exactly 12 entries, zero everywhere
// except at the slots of the labels present in `label_to_recon_gain`, and
// `computed_recon_gain_flag` has one bit set per populated slot. Each gain is
// scaled by 255 and converted to `uint8_t`.
//
// Args:
//   additional_logging: When true, logs every label and its gain.
//   label_to_recon_gain: Gains keyed by demixed channel label.
//   computed_recon_gains: Output; overwritten.
//   computed_recon_gain_flag: Output; overwritten.
//
// Returns `absl::InvalidArgumentError` if a label has no slot assigned below.
absl::Status ConvertReconGainsAndFlags(
    const bool additional_logging,
    const absl::flat_hash_map<ChannelLabel::Label, double>& label_to_recon_gain,
    std::vector<uint8_t>& computed_recon_gains,
    DecodedUleb128& computed_recon_gain_flag) {
  // `assign` (unlike `resize`) also clears values the caller may have left in
  // the vector, so slots that are not written below are guaranteed to be zero.
  computed_recon_gains.assign(12, 0);
  computed_recon_gain_flag = 0;
  for (const auto& [label, recon_gain] : label_to_recon_gain) {
    LOG_IF(INFO, additional_logging)
        << "Recon Gain[" << label << "]= " << recon_gain;

    // Bit position is based on Figure 5 of the Spec.
    int bit_position = 0;
    switch (label) {
      using enum ChannelLabel::Label;
      case kDemixedL7:
      case kDemixedL5:
      case kDemixedL3:
        // `kDemixedL2` is never demixed.
        bit_position = 0;
        break;
      case kDemixedR7:
      case kDemixedR5:
      case kDemixedR3:
      case kDemixedR2:
        // `kCentre` is never demixed. Skipping bit position = 1.
        bit_position = 2;
        break;
      case kDemixedLs5:
        bit_position = 3;
        break;
      case kDemixedRs5:
        bit_position = 4;
        break;
      case kDemixedLtf2:
        bit_position = 5;
        break;
      case kDemixedRtf2:
        bit_position = 6;
        break;
      case kDemixedLrs7:
        bit_position = 7;
        break;
      case kDemixedRrs7:
        bit_position = 8;
        break;
      case kDemixedLtb4:
        bit_position = 9;
        break;
      case kDemixedRtb4:
        bit_position = 10;
        // `kLFE` is never demixed. Skipping bit position = 11.
        break;
      default:
        // Do not fall through: continuing with `bit_position == 0` would set
        // the flag bit and overwrite the gain of the demixed left channel.
        LOG(ERROR) << "Unrecognized demixed channel label: " << label;
        return absl::InvalidArgumentError(
            "Unrecognized demixed channel label when converting recon gains.");
    }
    computed_recon_gain_flag |= static_cast<DecodedUleb128>(1) << bit_position;
    computed_recon_gains[bit_position] =
        static_cast<uint8_t>(recon_gain * 255.0);
  }
  return absl::OkStatus();
}
258
ComputeReconGains(const int layer_index,const ChannelNumbers & layer_channels,const ChannelNumbers & accumulated_channels,const bool additional_recon_gains_logging,const LabelSamplesMap & labeled_samples,const LabelSamplesMap & label_to_decoded_samples,const std::vector<bool> & recon_gain_is_present_flags,std::vector<uint8_t> & computed_recon_gains,DecodedUleb128 & computed_recon_gain_flag)259 absl::Status ComputeReconGains(
260 const int layer_index, const ChannelNumbers& layer_channels,
261 const ChannelNumbers& accumulated_channels,
262 const bool additional_recon_gains_logging,
263 const LabelSamplesMap& labeled_samples,
264 const LabelSamplesMap& label_to_decoded_samples,
265 const std::vector<bool>& recon_gain_is_present_flags,
266 std::vector<uint8_t>& computed_recon_gains,
267 DecodedUleb128& computed_recon_gain_flag) {
268 if (additional_recon_gains_logging) {
269 LogChannelNumbers(absl::StrCat("Layer[", layer_index, "]"), layer_channels);
270 }
271 absl::flat_hash_map<ChannelLabel::Label, double> label_to_recon_gain;
272 if (layer_index > 0) {
273 std::list<ChannelLabel::Label> demixed_channel_labels;
274 RETURN_IF_NOT_OK(FindDemixedChannels(accumulated_channels, layer_channels,
275 &demixed_channel_labels));
276
277 LOG_IF(INFO, additional_recon_gains_logging) << "Demixed channels: ";
278 for (const auto& label : demixed_channel_labels) {
279 RETURN_IF_NOT_OK(ReconGainGenerator::ComputeReconGain(
280 label, labeled_samples, label_to_decoded_samples,
281 additional_recon_gains_logging, label_to_recon_gain[label]));
282 }
283 }
284
285 if (recon_gain_is_present_flags[layer_index] !=
286 (!label_to_recon_gain.empty())) {
287 return absl::InvalidArgumentError(absl::StrCat(
288 "Mismatch of whether user specified recon gain is present: ",
289 recon_gain_is_present_flags[layer_index],
290 " vs whether recon gain should be computed: ",
291 !label_to_recon_gain.empty()));
292 }
293
294 RETURN_IF_NOT_OK(ConvertReconGainsAndFlags(
295 /*additional_logging=*/true, label_to_recon_gain, computed_recon_gains,
296 computed_recon_gain_flag));
297
298 return absl::OkStatus();
299 }
300
GenerateReconGainSubblock(const bool override_computed_recon_gains,const bool additional_recon_gains_logging,const IdLabeledFrameMap & id_to_labeled_frame,const IdLabeledFrameMap & id_to_labeled_decoded_frame,const iamf_tools_cli_proto::ReconGainInfoParameterData & metadata_recon_gain_info_parameter_data,const ReconGainParamDefinition * param_definition,std::unique_ptr<ParameterData> & parameter_data)301 absl::Status GenerateReconGainSubblock(
302 const bool override_computed_recon_gains,
303 const bool additional_recon_gains_logging,
304 const IdLabeledFrameMap& id_to_labeled_frame,
305 const IdLabeledFrameMap& id_to_labeled_decoded_frame,
306 const iamf_tools_cli_proto::ReconGainInfoParameterData&
307 metadata_recon_gain_info_parameter_data,
308 const ReconGainParamDefinition* param_definition,
309 std::unique_ptr<ParameterData>& parameter_data) {
310 parameter_data = param_definition->CreateParameterData();
311 auto* recon_gain_info_parameter_data =
312 static_cast<ReconGainInfoParameterData*>(parameter_data.get());
313 const auto num_layers = param_definition->aux_data_.size();
314 const auto& user_recon_gains_layers =
315 metadata_recon_gain_info_parameter_data.recon_gains_for_layer();
316 if (num_layers > 1 && num_layers != user_recon_gains_layers.size()) {
317 return absl::InvalidArgumentError(
318 absl::StrCat("There are ", num_layers, " layers of scalable ",
319 "audio element, but the user only specifies ",
320 user_recon_gains_layers.size(), " layers."));
321 }
322 recon_gain_info_parameter_data->recon_gain_elements.resize(num_layers);
323
324 const std::vector<bool>& recon_gain_is_present_flags =
325 recon_gain_info_parameter_data->recon_gain_is_present_flags;
326 for (int layer_index = 0; layer_index < num_layers; layer_index++) {
327 // Write out the user supplied gains. Depending on the mode these either
328 // match the computed recon gains or are used as an override. Write to
329 // output.
330 auto& output_recon_gain_element =
331 recon_gain_info_parameter_data->recon_gain_elements[layer_index];
332 if (!param_definition->aux_data_[layer_index].recon_gain_is_present_flag) {
333 // Skip computation and store no value in the output.
334 output_recon_gain_element.reset();
335 continue;
336 }
337 output_recon_gain_element.emplace(ReconGainElement{});
338
339 // Construct the bitmask indicating the channels where recon gains are
340 // present.
341 std::vector<uint8_t> user_recon_gains(12, 0);
342 DecodedUleb128 user_recon_gain_flag = 0;
343 for (const auto& [bit_position, user_recon_gain] :
344 user_recon_gains_layers[layer_index].recon_gain()) {
345 user_recon_gain_flag |= 1 << bit_position;
346 user_recon_gains[bit_position] = user_recon_gain;
347 }
348 for (const auto& [bit_position, user_recon_gain] :
349 user_recon_gains_layers[layer_index].recon_gain()) {
350 output_recon_gain_element->recon_gain[bit_position] =
351 user_recon_gains[bit_position];
352 }
353 output_recon_gain_element->recon_gain_flag = user_recon_gain_flag;
354
355 if (override_computed_recon_gains) {
356 continue;
357 }
358
359 // Compute the recon gains and validate they match the user supplied values.
360 std::vector<uint8_t> computed_recon_gains;
361 DecodedUleb128 computed_recon_gain_flag = 0;
362 const DecodedUleb128 audio_element_id = param_definition->audio_element_id_;
363 const auto labeled_frame_iter = id_to_labeled_frame.find(audio_element_id);
364 const auto labeled_decoded_frame_iter =
365 id_to_labeled_decoded_frame.find(audio_element_id);
366 if (labeled_frame_iter == id_to_labeled_frame.end() ||
367 labeled_decoded_frame_iter == id_to_labeled_decoded_frame.end()) {
368 return absl::InvalidArgumentError(absl::StrCat(
369 "Original or decoded audio frame for audio element ID= ",
370 audio_element_id, " not found when computing recon gains"));
371 }
372
373 const auto& layer_channels =
374 param_definition->aux_data_[layer_index].channel_numbers_for_layer;
375 const auto accumulated_channels =
376 (layer_index > 0 ? param_definition->aux_data_[layer_index - 1]
377 .channel_numbers_for_layer
378 : ChannelNumbers{0, 0, 0});
379 RETURN_IF_NOT_OK(
380 ComputeReconGains(layer_index, layer_channels, accumulated_channels,
381 additional_recon_gains_logging,
382 labeled_frame_iter->second.label_to_samples,
383 labeled_decoded_frame_iter->second.label_to_samples,
384 recon_gain_is_present_flags, computed_recon_gains,
385 computed_recon_gain_flag));
386
387 // Compare computed and user specified flag and recon gain values.
388 if (computed_recon_gain_flag != user_recon_gain_flag) {
389 return absl::InvalidArgumentError(absl::StrCat(
390 "Computed recon gain flag different from what user specified: ",
391 computed_recon_gain_flag, " vs ", user_recon_gain_flag));
392 }
393 bool recon_gains_match = true;
394 for (int i = 0; i < 12; i++) {
395 if (user_recon_gains[i] != computed_recon_gains[i]) {
396 // Find all mismatches before returning an error.
397 LOG(ERROR) << "Computed recon gain [" << i
398 << "] different from what user specified: "
399 << absl::StrCat(computed_recon_gains[i]) << " vs "
400 << absl::StrCat(user_recon_gains[i]);
401 recon_gains_match = false;
402 }
403 }
404 if (!recon_gains_match) {
405 return absl::InvalidArgumentError("Recon gains mismatch");
406 }
407 } // End of for (int layer_index ...)
408
409 return absl::OkStatus();
410 }
411
GenerateParameterBlockSubblock(const bool override_computed_recon_gains,const bool additional_recon_gains_logging,const IdLabeledFrameMap * id_to_labeled_frame,const IdLabeledFrameMap * id_to_labeled_decoded_frame,const ParamDefinitionVariant & param_definition_variant,const bool include_subblock_duration,const int subblock_index,const iamf_tools_cli_proto::ParameterSubblock & metadata_subblock,ParameterBlockObu & obu)412 absl::Status GenerateParameterBlockSubblock(
413 const bool override_computed_recon_gains,
414 const bool additional_recon_gains_logging,
415 const IdLabeledFrameMap* id_to_labeled_frame,
416 const IdLabeledFrameMap* id_to_labeled_decoded_frame,
417 const ParamDefinitionVariant& param_definition_variant,
418 const bool include_subblock_duration, const int subblock_index,
419 const iamf_tools_cli_proto::ParameterSubblock& metadata_subblock,
420 ParameterBlockObu& obu) {
421 if (include_subblock_duration) {
422 RETURN_IF_NOT_OK(obu.SetSubblockDuration(
423 subblock_index, metadata_subblock.subblock_duration()));
424 }
425
426 auto& obu_subblock_param_data = obu.subblocks_[subblock_index].param_data;
427 const auto param_definition_type =
428 GetParameterDefinitionType(param_definition_variant);
429 std::unique_ptr<ParameterData> parameter_data;
430 RETURN_IF_NOT_OK(
431 ValidateHasValue(param_definition_type, "`param_definition_type`."));
432 switch (*param_definition_type) {
433 using enum ParamDefinition::ParameterDefinitionType;
434 case kParameterDefinitionMixGain: {
435 auto* mix_gain_param_definition =
436 std::get_if<MixGainParamDefinition>(¶m_definition_variant);
437 RETURN_IF_NOT_OK(
438 ValidateNotNull(mix_gain_param_definition, "MixGainParamDefinition"));
439 RETURN_IF_NOT_OK(
440 GenerateMixGainSubblock(metadata_subblock.mix_gain_parameter_data(),
441 mix_gain_param_definition, parameter_data));
442 break;
443 }
444 case kParameterDefinitionDemixing: {
445 if (subblock_index > 1) {
446 return absl::InvalidArgumentError(
447 "There should be only one subblock for demixing info.");
448 }
449 auto* demixing_param_definition =
450 std::get_if<DemixingParamDefinition>(¶m_definition_variant);
451 RETURN_IF_NOT_OK(ValidateNotNull(demixing_param_definition,
452 "DemixingParamDefinition"));
453 parameter_data = demixing_param_definition->CreateParameterData();
454 RETURN_IF_NOT_OK(CopyDemixingInfoParameterData(
455 metadata_subblock.demixing_info_parameter_data(),
456 *static_cast<DemixingInfoParameterData*>(parameter_data.get())));
457 break;
458 }
459 case kParameterDefinitionReconGain: {
460 if (subblock_index > 1) {
461 return absl::InvalidArgumentError(
462 "There should be only one subblock for recon gain info.");
463 }
464 auto* recon_gain_param_definition =
465 std::get_if<ReconGainParamDefinition>(¶m_definition_variant);
466 RETURN_IF_NOT_OK(ValidateNotNull(recon_gain_param_definition,
467 "ReconGainParamDefinition"));
468 RETURN_IF_NOT_OK(GenerateReconGainSubblock(
469 override_computed_recon_gains, additional_recon_gains_logging,
470 *id_to_labeled_frame, *id_to_labeled_decoded_frame,
471 metadata_subblock.recon_gain_info_parameter_data(),
472 recon_gain_param_definition, parameter_data));
473 break;
474 }
475 default:
476 // TODO(b/289080630): Support the extension fields here.
477 return absl::InvalidArgumentError(absl::StrCat(
478 "Unsupported param definition type= ", *param_definition_type));
479 }
480 obu_subblock_param_data = std::move(parameter_data);
481
482 return absl::OkStatus();
483 }
484
PopulateCommonFields(const iamf_tools_cli_proto::ParameterBlockObuMetadata & parameter_block_metadata,const ParamDefinition & param_definition,GlobalTimingModule & global_timing_module,ParameterBlockWithData & parameter_block_with_data)485 absl::Status PopulateCommonFields(
486 const iamf_tools_cli_proto::ParameterBlockObuMetadata&
487 parameter_block_metadata,
488 const ParamDefinition& param_definition,
489 GlobalTimingModule& global_timing_module,
490 ParameterBlockWithData& parameter_block_with_data) {
491 // Get the duration from the parameter definition or the OBU itself as
492 // applicable.
493 const DecodedUleb128 duration = param_definition.param_definition_mode_ == 1
494 ? parameter_block_metadata.duration()
495 : param_definition.duration_;
496
497 // Populate the timing information.
498 RETURN_IF_NOT_OK(global_timing_module.GetNextParameterBlockTimestamps(
499 parameter_block_metadata.parameter_id(),
500 parameter_block_metadata.start_timestamp(), duration,
501 parameter_block_with_data.start_timestamp,
502 parameter_block_with_data.end_timestamp));
503
504 // Populate the OBU.
505 const DecodedUleb128 parameter_id = parameter_block_metadata.parameter_id();
506 parameter_block_with_data.obu = std::make_unique<ParameterBlockObu>(
507 GetHeaderFromMetadata(parameter_block_metadata.obu_header()),
508 parameter_id, param_definition);
509
510 // Several fields are dependent on `param_definition_mode`.
511 if (param_definition.param_definition_mode_ == 1) {
512 RETURN_IF_NOT_OK(parameter_block_with_data.obu->InitializeSubblocks(
513 parameter_block_metadata.duration(),
514 parameter_block_metadata.constant_subblock_duration(),
515 parameter_block_metadata.num_subblocks()));
516 } else {
517 RETURN_IF_NOT_OK(parameter_block_with_data.obu->InitializeSubblocks());
518 }
519
520 return absl::OkStatus();
521 }
522
PopulateSubblocks(const iamf_tools_cli_proto::ParameterBlockObuMetadata & parameter_block_metadata,const bool override_computed_recon_gains,const bool additional_recon_gains_logging,const IdLabeledFrameMap * id_to_labeled_frame,const IdLabeledFrameMap * id_to_labeled_decoded_frame,const ParamDefinitionVariant & param_definition_variant,ParameterBlockWithData & output_parameter_block)523 absl::Status PopulateSubblocks(
524 const iamf_tools_cli_proto::ParameterBlockObuMetadata&
525 parameter_block_metadata,
526 const bool override_computed_recon_gains,
527 const bool additional_recon_gains_logging,
528 const IdLabeledFrameMap* id_to_labeled_frame,
529 const IdLabeledFrameMap* id_to_labeled_decoded_frame,
530 const ParamDefinitionVariant& param_definition_variant,
531 ParameterBlockWithData& output_parameter_block) {
532 auto& parameter_block_obu = *output_parameter_block.obu;
533 const DecodedUleb128 num_subblocks = parameter_block_obu.GetNumSubblocks();
534
535 // All subblocks will include `subblock_duration` or none will include it.
536 const bool include_subblock_duration =
537 GetParameterDefinitionMode(param_definition_variant) == 1 &&
538 parameter_block_obu.GetConstantSubblockDuration() == 0;
539
540 if (num_subblocks != parameter_block_metadata.subblocks_size()) {
541 return absl::InvalidArgumentError(
542 absl::StrCat("Expected ", num_subblocks, " subblocks, got ",
543 parameter_block_metadata.subblocks_size()));
544 }
545 for (int i = 0; i < num_subblocks; ++i) {
546 RETURN_IF_NOT_OK(GenerateParameterBlockSubblock(
547 override_computed_recon_gains, additional_recon_gains_logging,
548 id_to_labeled_frame, id_to_labeled_decoded_frame,
549 param_definition_variant, include_subblock_duration, i,
550 parameter_block_metadata.subblocks(i), parameter_block_obu));
551 }
552
553 return absl::OkStatus();
554 }
555
LogParameterBlockObus(const std::list<ParameterBlockWithData> & output_parameter_blocks)556 absl::Status LogParameterBlockObus(
557 const std::list<ParameterBlockWithData>& output_parameter_blocks) {
558 // Log only the first and the last parameter blocks.
559 if (output_parameter_blocks.empty()) {
560 return absl::OkStatus();
561 }
562 std::vector<const ParameterBlockWithData*> to_log = {
563 &output_parameter_blocks.front()};
564 if (output_parameter_blocks.size() > 1) {
565 to_log.push_back(&output_parameter_blocks.back());
566 }
567
568 for (const auto* parameter_block_with_data : to_log) {
569 parameter_block_with_data->obu->PrintObu();
570 LOG(INFO) << " // start_timestamp= "
571 << parameter_block_with_data->start_timestamp;
572 LOG(INFO) << " // end_timestamp= "
573 << parameter_block_with_data->end_timestamp;
574 }
575
576 return absl::OkStatus();
577 }
578
579 } // namespace
580
// Checks that every stored parameter definition has a type and that the type
// is demixing, mix gain or recon gain.
//
// `audio_elements` is not read by this function. Returns
// `absl::InvalidArgumentError` for any other type, or the status of
// `ValidateHasValue` when a definition reports no type.
absl::Status ParameterBlockGenerator::Initialize(
    const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
        audio_elements) {
  for (const auto& [parameter_id, param_definition_variant] :
       param_definition_variants_) {
    const auto type = GetParameterDefinitionType(param_definition_variant);
    RETURN_IF_NOT_OK(ValidateHasValue(type, "param_definition_type"));

    const bool is_supported =
        *type == ParamDefinition::kParameterDefinitionDemixing ||
        *type == ParamDefinition::kParameterDefinitionMixGain ||
        *type == ParamDefinition::kParameterDefinitionReconGain;
    if (!is_supported) {
      return absl::InvalidArgumentError(
          absl::StrCat("Unsupported parameter type: ", *type));
    }
  }

  return absl::OkStatus();
}
602
AddMetadata(const iamf_tools_cli_proto::ParameterBlockObuMetadata & parameter_block_metadata)603 absl::Status ParameterBlockGenerator::AddMetadata(
604 const iamf_tools_cli_proto::ParameterBlockObuMetadata&
605 parameter_block_metadata) {
606 const auto& param_definition_iter =
607 param_definition_variants_.find(parameter_block_metadata.parameter_id());
608 if (param_definition_iter == param_definition_variants_.end()) {
609 return absl::InvalidArgumentError(
610 absl::StrCat("No parameter definition found for parameter ID= ",
611 parameter_block_metadata.parameter_id()));
612 }
613 const auto& param_definition_type = std::visit(
614 [](const auto& param_definition) { return param_definition.GetType(); },
615 param_definition_iter->second);
616 RETURN_IF_NOT_OK(
617 ValidateHasValue(param_definition_type, "`param_definition_type`."));
618 typed_proto_metadata_[*param_definition_type].push_back(
619 parameter_block_metadata);
620
621 return absl::OkStatus();
622 }
623
GenerateDemixing(GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)624 absl::Status ParameterBlockGenerator::GenerateDemixing(
625 GlobalTimingModule& global_timing_module,
626 std::list<ParameterBlockWithData>& output_parameter_blocks) {
627 RETURN_IF_NOT_OK(GenerateParameterBlocks(
628 /*id_to_labeled_frame=*/nullptr,
629 /*id_to_labeled_decoded_frame=*/nullptr,
630 typed_proto_metadata_[ParamDefinition::kParameterDefinitionDemixing],
631 global_timing_module, output_parameter_blocks));
632
633 return absl::OkStatus();
634 }
635
GenerateMixGain(GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)636 absl::Status ParameterBlockGenerator::GenerateMixGain(
637 GlobalTimingModule& global_timing_module,
638 std::list<ParameterBlockWithData>& output_parameter_blocks) {
639 RETURN_IF_NOT_OK(GenerateParameterBlocks(
640 /*id_to_labeled_frame=*/nullptr,
641 /*id_to_labeled_decoded_frame=*/nullptr,
642 typed_proto_metadata_[ParamDefinition::kParameterDefinitionMixGain],
643 global_timing_module, output_parameter_blocks));
644
645 return absl::OkStatus();
646 }
647
648 // TODO(b/306319126): Generate Recon Gain iteratively now that the audio frame
649 // decoder decodes iteratively.
GenerateReconGain(const IdLabeledFrameMap & id_to_labeled_frame,const IdLabeledFrameMap & id_to_labeled_decoded_frame,GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)650 absl::Status ParameterBlockGenerator::GenerateReconGain(
651 const IdLabeledFrameMap& id_to_labeled_frame,
652 const IdLabeledFrameMap& id_to_labeled_decoded_frame,
653 GlobalTimingModule& global_timing_module,
654 std::list<ParameterBlockWithData>& output_parameter_blocks) {
655 RETURN_IF_NOT_OK(GenerateParameterBlocks(
656 &id_to_labeled_frame, &id_to_labeled_decoded_frame,
657 typed_proto_metadata_[ParamDefinition::kParameterDefinitionReconGain],
658 global_timing_module, output_parameter_blocks));
659 return absl::OkStatus();
660 }
661
GenerateParameterBlocks(const IdLabeledFrameMap * id_to_labeled_frame,const IdLabeledFrameMap * id_to_labeled_decoded_frame,std::list<iamf_tools_cli_proto::ParameterBlockObuMetadata> & proto_metadata_list,GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)662 absl::Status ParameterBlockGenerator::GenerateParameterBlocks(
663 const IdLabeledFrameMap* id_to_labeled_frame,
664 const IdLabeledFrameMap* id_to_labeled_decoded_frame,
665 std::list<iamf_tools_cli_proto::ParameterBlockObuMetadata>&
666 proto_metadata_list,
667 GlobalTimingModule& global_timing_module,
668 std::list<ParameterBlockWithData>& output_parameter_blocks) {
669 for (auto& parameter_block_metadata : proto_metadata_list) {
670 ParameterBlockWithData output_parameter_block;
671 const auto& param_definition_variant =
672 param_definition_variants_.at(parameter_block_metadata.parameter_id());
673 const auto* param_definition_base = std::visit(
674 [](const auto& param_definition) {
675 return static_cast<const ParamDefinition*>(¶m_definition);
676 },
677 param_definition_variant);
678 RETURN_IF_NOT_OK(
679 PopulateCommonFields(parameter_block_metadata, *param_definition_base,
680 global_timing_module, output_parameter_block));
681
682 RETURN_IF_NOT_OK(PopulateSubblocks(
683 parameter_block_metadata, override_computed_recon_gains_,
684 additional_recon_gains_logging_, id_to_labeled_frame,
685 id_to_labeled_decoded_frame, param_definition_variant,
686 output_parameter_block));
687
688 // Disable some verbose logging after the first recon gain block is
689 // produced.
690 if (!override_computed_recon_gains_) {
691 additional_recon_gains_logging_ = false;
692 }
693
694 output_parameter_blocks.push_back(std::move(output_parameter_block));
695 }
696
697 RETURN_IF_NOT_OK(LogParameterBlockObus(output_parameter_blocks));
698
699 // Clear the metadata of this frame.
700 proto_metadata_list.clear();
701
702 return absl::OkStatus();
703 }
704
705 } // namespace iamf_tools
706