• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 3-Clause Clear License
5  * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
6  * License was not distributed with this source code in the LICENSE file, you
7  * can obtain it at www.aomedia.org/license/software-license/bsd-3-c-c. If the
8  * Alliance for Open Media Patent License 1.0 was not distributed with this
9  * source code in the PATENTS file, you can obtain it at
10  * www.aomedia.org/license/patent.
11  */
12 #include "iamf/cli/proto_conversion/proto_to_obu/parameter_block_generator.h"
13 
14 #include <cstdint>
15 #include <list>
16 #include <memory>
17 #include <optional>
18 #include <utility>
19 #include <variant>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/log/check.h"
25 #include "absl/log/log.h"
26 #include "absl/status/status.h"
27 #include "absl/strings/str_cat.h"
28 #include "iamf/cli/audio_element_with_data.h"
29 #include "iamf/cli/channel_label.h"
30 #include "iamf/cli/cli_util.h"
31 #include "iamf/cli/demixing_module.h"
32 #include "iamf/cli/global_timing_module.h"
33 #include "iamf/cli/parameter_block_with_data.h"
34 #include "iamf/cli/proto/parameter_block.pb.h"
35 #include "iamf/cli/proto/parameter_data.pb.h"
36 #include "iamf/cli/proto_conversion/proto_utils.h"
37 #include "iamf/cli/recon_gain_generator.h"
38 #include "iamf/common/utils/macros.h"
39 #include "iamf/common/utils/numeric_utils.h"
40 #include "iamf/common/utils/validation_utils.h"
41 #include "iamf/obu/demixing_info_parameter_data.h"
42 #include "iamf/obu/demixing_param_definition.h"
43 #include "iamf/obu/mix_gain_parameter_data.h"
44 #include "iamf/obu/param_definition_variant.h"
45 #include "iamf/obu/param_definitions.h"
46 #include "iamf/obu/parameter_block.h"
47 #include "iamf/obu/recon_gain_info_parameter_data.h"
48 #include "iamf/obu/types.h"
49 
50 namespace iamf_tools {
51 
52 namespace {
53 
54 std::optional<ParamDefinition::ParameterDefinitionType>
GetParameterDefinitionType(const ParamDefinitionVariant & parameter_definition_variant)55 GetParameterDefinitionType(
56     const ParamDefinitionVariant& parameter_definition_variant) {
57   return std::visit(
58       [](const auto& param_definition) { return param_definition.GetType(); },
59       parameter_definition_variant);
60 }
61 
GetParameterDefinitionMode(const ParamDefinitionVariant & parameter_definition_variant)62 uint8_t GetParameterDefinitionMode(
63     const ParamDefinitionVariant& parameter_definition_variant) {
64   return std::visit(
65       [](const auto& param_definition) {
66         return param_definition.param_definition_mode_;
67       },
68       parameter_definition_variant);
69 }
70 
GenerateMixGainSubblock(const iamf_tools_cli_proto::MixGainParameterData & metadata_mix_gain_parameter_data,const MixGainParamDefinition * param_definition,std::unique_ptr<ParameterData> & parameter_data)71 absl::Status GenerateMixGainSubblock(
72     const iamf_tools_cli_proto::MixGainParameterData&
73         metadata_mix_gain_parameter_data,
74     const MixGainParamDefinition* param_definition,
75     std::unique_ptr<ParameterData>& parameter_data) {
76   parameter_data = param_definition->CreateParameterData();
77   auto* mix_gain_parameter_data =
78       static_cast<MixGainParameterData*>(parameter_data.get());
79   switch (metadata_mix_gain_parameter_data.animation_type()) {
80     using enum iamf_tools_cli_proto::AnimationType;
81     case ANIMATE_STEP: {
82       const auto& metadata_animation =
83           metadata_mix_gain_parameter_data.param_data().step();
84       mix_gain_parameter_data->animation_type =
85           MixGainParameterData::kAnimateStep;
86       AnimationStepInt16 obu_animation;
87       RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
88           "AnimationStepInt16.start_point_value",
89           metadata_animation.start_point_value(),
90           obu_animation.start_point_value));
91       mix_gain_parameter_data->param_data = obu_animation;
92       break;
93     }
94     case ANIMATE_LINEAR: {
95       const auto& metadata_animation =
96           metadata_mix_gain_parameter_data.param_data().linear();
97       mix_gain_parameter_data->animation_type =
98           MixGainParameterData::kAnimateLinear;
99 
100       AnimationLinearInt16 obu_animation;
101       RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
102           "AnimationLinearInt16.start_point_value",
103           metadata_animation.start_point_value(),
104           obu_animation.start_point_value));
105       RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
106           "AnimationLinearInt16.end_point_value",
107           metadata_animation.end_point_value(), obu_animation.end_point_value));
108       mix_gain_parameter_data->param_data = obu_animation;
109       break;
110     }
111     case ANIMATE_BEZIER: {
112       const auto& metadata_animation =
113           metadata_mix_gain_parameter_data.param_data().bezier();
114       mix_gain_parameter_data->animation_type =
115           MixGainParameterData::kAnimateBezier;
116       AnimationBezierInt16 obu_animation;
117       RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
118           "AnimationBezierInt16.start_point_value",
119           metadata_animation.start_point_value(),
120           obu_animation.start_point_value));
121       RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
122           "AnimationBezierInt16.end_point_value",
123           metadata_animation.end_point_value(), obu_animation.end_point_value));
124       RETURN_IF_NOT_OK(StaticCastIfInRange<int32_t, int16_t>(
125           "AnimationBezierInt16.control_point_value",
126           metadata_animation.control_point_value(),
127           obu_animation.control_point_value));
128       RETURN_IF_NOT_OK(StaticCastIfInRange<uint32_t, uint8_t>(
129           "AnimationBezierInt16.control_point_relative_time",
130           metadata_animation.control_point_relative_time(),
131           obu_animation.control_point_relative_time));
132       mix_gain_parameter_data->param_data = obu_animation;
133       break;
134     }
135     default:
136       return absl::InvalidArgumentError(
137           absl::StrCat("Unrecognized animation type= ",
138                        metadata_mix_gain_parameter_data.animation_type()));
139   }
140 
141   return absl::OkStatus();
142 }
143 
FindDemixedChannels(const ChannelNumbers & accumulated_channels,const ChannelNumbers & layer_channels,std::list<ChannelLabel::Label> * const demixed_channel_labels)144 absl::Status FindDemixedChannels(
145     const ChannelNumbers& accumulated_channels,
146     const ChannelNumbers& layer_channels,
147     std::list<ChannelLabel::Label>* const demixed_channel_labels) {
148   using enum ChannelLabel::Label;
149   for (int surround = accumulated_channels.surround + 1;
150        surround <= layer_channels.surround; surround++) {
151     switch (surround) {
152       case 2:
153         // Previous layer is Mono, this layer is Stereo.
154         if (accumulated_channels.surround == 1) {
155           demixed_channel_labels->push_back(kDemixedR2);
156         }
157         break;
158       case 3:
159         demixed_channel_labels->push_back(kDemixedL3);
160         demixed_channel_labels->push_back(kDemixedR3);
161         break;
162       case 5:
163         demixed_channel_labels->push_back(kDemixedLs5);
164         demixed_channel_labels->push_back(kDemixedRs5);
165         break;
166       case 7:
167         demixed_channel_labels->push_back(kDemixedL7);
168         demixed_channel_labels->push_back(kDemixedR7);
169         demixed_channel_labels->push_back(kDemixedLrs7);
170         demixed_channel_labels->push_back(kDemixedRrs7);
171         break;
172       default:
173         if (surround > 7) {
174           return absl::InvalidArgumentError(absl::StrCat(
175               "Unsupported number of surround channels: ", surround));
176         }
177         break;
178     }
179   }
180 
181   if (accumulated_channels.height == 2) {
182     if (layer_channels.height == 4) {
183       demixed_channel_labels->push_back(kDemixedLtb4);
184       demixed_channel_labels->push_back(kDemixedRtb4);
185     } else if (layer_channels.height == 2 &&
186                accumulated_channels.surround == 3 &&
187                layer_channels.surround > 3) {
188       demixed_channel_labels->push_back(kDemixedLtf2);
189       demixed_channel_labels->push_back(kDemixedRtf2);
190     }
191   }
192 
193   return absl::OkStatus();
194 }
195 
ConvertReconGainsAndFlags(const bool additional_logging,const absl::flat_hash_map<ChannelLabel::Label,double> & label_to_recon_gain,std::vector<uint8_t> & computed_recon_gains,DecodedUleb128 & computed_recon_gain_flag)196 absl::Status ConvertReconGainsAndFlags(
197     const bool additional_logging,
198     const absl::flat_hash_map<ChannelLabel::Label, double>& label_to_recon_gain,
199     std::vector<uint8_t>& computed_recon_gains,
200     DecodedUleb128& computed_recon_gain_flag) {
201   computed_recon_gains.resize(12, 0);
202   computed_recon_gain_flag = 0;
203   for (const auto& [label, recon_gain] : label_to_recon_gain) {
204     LOG_IF(INFO, additional_logging)
205         << "Recon Gain[" << label << "]= " << recon_gain;
206 
207     // Bit position is based on Figure 5 of the Spec.
208     int bit_position = 0;
209     switch (label) {
210       using enum ChannelLabel::Label;
211       case kDemixedL7:
212       case kDemixedL5:
213       case kDemixedL3:
214         // `kDemixedL2` is never demixed.
215         bit_position = 0;
216         break;
217       case kDemixedR7:
218       case kDemixedR5:
219       case kDemixedR3:
220       case kDemixedR2:
221         // `kCentre` is never demixed. Skipping bit position = 1.
222         bit_position = 2;
223         break;
224       case kDemixedLs5:
225         bit_position = 3;
226         break;
227       case kDemixedRs5:
228         bit_position = 4;
229         break;
230       case kDemixedLtf2:
231         bit_position = 5;
232         break;
233       case kDemixedRtf2:
234         bit_position = 6;
235         break;
236       case kDemixedLrs7:
237         bit_position = 7;
238         break;
239       case kDemixedRrs7:
240         bit_position = 8;
241         break;
242       case kDemixedLtb4:
243         bit_position = 9;
244         break;
245       case kDemixedRtb4:
246         bit_position = 10;
247         // `kLFE` is never demixed. Skipping bit position = 11.
248         break;
249       default:
250         LOG(ERROR) << "Unrecognized demixed channel label: " << label;
251     }
252     computed_recon_gain_flag |= 1 << bit_position;
253     computed_recon_gains[bit_position] =
254         static_cast<uint8_t>(recon_gain * 255.0);
255   }
256   return absl::OkStatus();
257 }
258 
ComputeReconGains(const int layer_index,const ChannelNumbers & layer_channels,const ChannelNumbers & accumulated_channels,const bool additional_recon_gains_logging,const LabelSamplesMap & labeled_samples,const LabelSamplesMap & label_to_decoded_samples,const std::vector<bool> & recon_gain_is_present_flags,std::vector<uint8_t> & computed_recon_gains,DecodedUleb128 & computed_recon_gain_flag)259 absl::Status ComputeReconGains(
260     const int layer_index, const ChannelNumbers& layer_channels,
261     const ChannelNumbers& accumulated_channels,
262     const bool additional_recon_gains_logging,
263     const LabelSamplesMap& labeled_samples,
264     const LabelSamplesMap& label_to_decoded_samples,
265     const std::vector<bool>& recon_gain_is_present_flags,
266     std::vector<uint8_t>& computed_recon_gains,
267     DecodedUleb128& computed_recon_gain_flag) {
268   if (additional_recon_gains_logging) {
269     LogChannelNumbers(absl::StrCat("Layer[", layer_index, "]"), layer_channels);
270   }
271   absl::flat_hash_map<ChannelLabel::Label, double> label_to_recon_gain;
272   if (layer_index > 0) {
273     std::list<ChannelLabel::Label> demixed_channel_labels;
274     RETURN_IF_NOT_OK(FindDemixedChannels(accumulated_channels, layer_channels,
275                                          &demixed_channel_labels));
276 
277     LOG_IF(INFO, additional_recon_gains_logging) << "Demixed channels: ";
278     for (const auto& label : demixed_channel_labels) {
279       RETURN_IF_NOT_OK(ReconGainGenerator::ComputeReconGain(
280           label, labeled_samples, label_to_decoded_samples,
281           additional_recon_gains_logging, label_to_recon_gain[label]));
282     }
283   }
284 
285   if (recon_gain_is_present_flags[layer_index] !=
286       (!label_to_recon_gain.empty())) {
287     return absl::InvalidArgumentError(absl::StrCat(
288         "Mismatch of whether user specified recon gain is present: ",
289         recon_gain_is_present_flags[layer_index],
290         " vs whether recon gain should be computed: ",
291         !label_to_recon_gain.empty()));
292   }
293 
294   RETURN_IF_NOT_OK(ConvertReconGainsAndFlags(
295       /*additional_logging=*/true, label_to_recon_gain, computed_recon_gains,
296       computed_recon_gain_flag));
297 
298   return absl::OkStatus();
299 }
300 
GenerateReconGainSubblock(const bool override_computed_recon_gains,const bool additional_recon_gains_logging,const IdLabeledFrameMap & id_to_labeled_frame,const IdLabeledFrameMap & id_to_labeled_decoded_frame,const iamf_tools_cli_proto::ReconGainInfoParameterData & metadata_recon_gain_info_parameter_data,const ReconGainParamDefinition * param_definition,std::unique_ptr<ParameterData> & parameter_data)301 absl::Status GenerateReconGainSubblock(
302     const bool override_computed_recon_gains,
303     const bool additional_recon_gains_logging,
304     const IdLabeledFrameMap& id_to_labeled_frame,
305     const IdLabeledFrameMap& id_to_labeled_decoded_frame,
306     const iamf_tools_cli_proto::ReconGainInfoParameterData&
307         metadata_recon_gain_info_parameter_data,
308     const ReconGainParamDefinition* param_definition,
309     std::unique_ptr<ParameterData>& parameter_data) {
310   parameter_data = param_definition->CreateParameterData();
311   auto* recon_gain_info_parameter_data =
312       static_cast<ReconGainInfoParameterData*>(parameter_data.get());
313   const auto num_layers = param_definition->aux_data_.size();
314   const auto& user_recon_gains_layers =
315       metadata_recon_gain_info_parameter_data.recon_gains_for_layer();
316   if (num_layers > 1 && num_layers != user_recon_gains_layers.size()) {
317     return absl::InvalidArgumentError(
318         absl::StrCat("There are ", num_layers, " layers of scalable  ",
319                      "audio element, but the user only specifies ",
320                      user_recon_gains_layers.size(), " layers."));
321   }
322   recon_gain_info_parameter_data->recon_gain_elements.resize(num_layers);
323 
324   const std::vector<bool>& recon_gain_is_present_flags =
325       recon_gain_info_parameter_data->recon_gain_is_present_flags;
326   for (int layer_index = 0; layer_index < num_layers; layer_index++) {
327     // Write out the user supplied gains. Depending on the mode these either
328     // match the computed recon gains or are used as an override. Write to
329     // output.
330     auto& output_recon_gain_element =
331         recon_gain_info_parameter_data->recon_gain_elements[layer_index];
332     if (!param_definition->aux_data_[layer_index].recon_gain_is_present_flag) {
333       // Skip computation and store no value in the output.
334       output_recon_gain_element.reset();
335       continue;
336     }
337     output_recon_gain_element.emplace(ReconGainElement{});
338 
339     // Construct the bitmask indicating the channels where recon gains are
340     // present.
341     std::vector<uint8_t> user_recon_gains(12, 0);
342     DecodedUleb128 user_recon_gain_flag = 0;
343     for (const auto& [bit_position, user_recon_gain] :
344          user_recon_gains_layers[layer_index].recon_gain()) {
345       user_recon_gain_flag |= 1 << bit_position;
346       user_recon_gains[bit_position] = user_recon_gain;
347     }
348     for (const auto& [bit_position, user_recon_gain] :
349          user_recon_gains_layers[layer_index].recon_gain()) {
350       output_recon_gain_element->recon_gain[bit_position] =
351           user_recon_gains[bit_position];
352     }
353     output_recon_gain_element->recon_gain_flag = user_recon_gain_flag;
354 
355     if (override_computed_recon_gains) {
356       continue;
357     }
358 
359     // Compute the recon gains and validate they match the user supplied values.
360     std::vector<uint8_t> computed_recon_gains;
361     DecodedUleb128 computed_recon_gain_flag = 0;
362     const DecodedUleb128 audio_element_id = param_definition->audio_element_id_;
363     const auto labeled_frame_iter = id_to_labeled_frame.find(audio_element_id);
364     const auto labeled_decoded_frame_iter =
365         id_to_labeled_decoded_frame.find(audio_element_id);
366     if (labeled_frame_iter == id_to_labeled_frame.end() ||
367         labeled_decoded_frame_iter == id_to_labeled_decoded_frame.end()) {
368       return absl::InvalidArgumentError(absl::StrCat(
369           "Original or decoded audio frame for audio element ID= ",
370           audio_element_id, " not found when computing recon gains"));
371     }
372 
373     const auto& layer_channels =
374         param_definition->aux_data_[layer_index].channel_numbers_for_layer;
375     const auto accumulated_channels =
376         (layer_index > 0 ? param_definition->aux_data_[layer_index - 1]
377                                .channel_numbers_for_layer
378                          : ChannelNumbers{0, 0, 0});
379     RETURN_IF_NOT_OK(
380         ComputeReconGains(layer_index, layer_channels, accumulated_channels,
381                           additional_recon_gains_logging,
382                           labeled_frame_iter->second.label_to_samples,
383                           labeled_decoded_frame_iter->second.label_to_samples,
384                           recon_gain_is_present_flags, computed_recon_gains,
385                           computed_recon_gain_flag));
386 
387     // Compare computed and user specified flag and recon gain values.
388     if (computed_recon_gain_flag != user_recon_gain_flag) {
389       return absl::InvalidArgumentError(absl::StrCat(
390           "Computed recon gain flag different from what user specified: ",
391           computed_recon_gain_flag, " vs ", user_recon_gain_flag));
392     }
393     bool recon_gains_match = true;
394     for (int i = 0; i < 12; i++) {
395       if (user_recon_gains[i] != computed_recon_gains[i]) {
396         // Find all mismatches before returning an error.
397         LOG(ERROR) << "Computed recon gain [" << i
398                    << "] different from what user specified: "
399                    << absl::StrCat(computed_recon_gains[i]) << " vs "
400                    << absl::StrCat(user_recon_gains[i]);
401         recon_gains_match = false;
402       }
403     }
404     if (!recon_gains_match) {
405       return absl::InvalidArgumentError("Recon gains mismatch");
406     }
407   }  // End of for (int layer_index ...)
408 
409   return absl::OkStatus();
410 }
411 
GenerateParameterBlockSubblock(const bool override_computed_recon_gains,const bool additional_recon_gains_logging,const IdLabeledFrameMap * id_to_labeled_frame,const IdLabeledFrameMap * id_to_labeled_decoded_frame,const ParamDefinitionVariant & param_definition_variant,const bool include_subblock_duration,const int subblock_index,const iamf_tools_cli_proto::ParameterSubblock & metadata_subblock,ParameterBlockObu & obu)412 absl::Status GenerateParameterBlockSubblock(
413     const bool override_computed_recon_gains,
414     const bool additional_recon_gains_logging,
415     const IdLabeledFrameMap* id_to_labeled_frame,
416     const IdLabeledFrameMap* id_to_labeled_decoded_frame,
417     const ParamDefinitionVariant& param_definition_variant,
418     const bool include_subblock_duration, const int subblock_index,
419     const iamf_tools_cli_proto::ParameterSubblock& metadata_subblock,
420     ParameterBlockObu& obu) {
421   if (include_subblock_duration) {
422     RETURN_IF_NOT_OK(obu.SetSubblockDuration(
423         subblock_index, metadata_subblock.subblock_duration()));
424   }
425 
426   auto& obu_subblock_param_data = obu.subblocks_[subblock_index].param_data;
427   const auto param_definition_type =
428       GetParameterDefinitionType(param_definition_variant);
429   std::unique_ptr<ParameterData> parameter_data;
430   RETURN_IF_NOT_OK(
431       ValidateHasValue(param_definition_type, "`param_definition_type`."));
432   switch (*param_definition_type) {
433     using enum ParamDefinition::ParameterDefinitionType;
434     case kParameterDefinitionMixGain: {
435       auto* mix_gain_param_definition =
436           std::get_if<MixGainParamDefinition>(&param_definition_variant);
437       RETURN_IF_NOT_OK(
438           ValidateNotNull(mix_gain_param_definition, "MixGainParamDefinition"));
439       RETURN_IF_NOT_OK(
440           GenerateMixGainSubblock(metadata_subblock.mix_gain_parameter_data(),
441                                   mix_gain_param_definition, parameter_data));
442       break;
443     }
444     case kParameterDefinitionDemixing: {
445       if (subblock_index > 1) {
446         return absl::InvalidArgumentError(
447             "There should be only one subblock for demixing info.");
448       }
449       auto* demixing_param_definition =
450           std::get_if<DemixingParamDefinition>(&param_definition_variant);
451       RETURN_IF_NOT_OK(ValidateNotNull(demixing_param_definition,
452                                        "DemixingParamDefinition"));
453       parameter_data = demixing_param_definition->CreateParameterData();
454       RETURN_IF_NOT_OK(CopyDemixingInfoParameterData(
455           metadata_subblock.demixing_info_parameter_data(),
456           *static_cast<DemixingInfoParameterData*>(parameter_data.get())));
457       break;
458     }
459     case kParameterDefinitionReconGain: {
460       if (subblock_index > 1) {
461         return absl::InvalidArgumentError(
462             "There should be only one subblock for recon gain info.");
463       }
464       auto* recon_gain_param_definition =
465           std::get_if<ReconGainParamDefinition>(&param_definition_variant);
466       RETURN_IF_NOT_OK(ValidateNotNull(recon_gain_param_definition,
467                                        "ReconGainParamDefinition"));
468       RETURN_IF_NOT_OK(GenerateReconGainSubblock(
469           override_computed_recon_gains, additional_recon_gains_logging,
470           *id_to_labeled_frame, *id_to_labeled_decoded_frame,
471           metadata_subblock.recon_gain_info_parameter_data(),
472           recon_gain_param_definition, parameter_data));
473       break;
474     }
475     default:
476       // TODO(b/289080630): Support the extension fields here.
477       return absl::InvalidArgumentError(absl::StrCat(
478           "Unsupported param definition type= ", *param_definition_type));
479   }
480   obu_subblock_param_data = std::move(parameter_data);
481 
482   return absl::OkStatus();
483 }
484 
PopulateCommonFields(const iamf_tools_cli_proto::ParameterBlockObuMetadata & parameter_block_metadata,const ParamDefinition & param_definition,GlobalTimingModule & global_timing_module,ParameterBlockWithData & parameter_block_with_data)485 absl::Status PopulateCommonFields(
486     const iamf_tools_cli_proto::ParameterBlockObuMetadata&
487         parameter_block_metadata,
488     const ParamDefinition& param_definition,
489     GlobalTimingModule& global_timing_module,
490     ParameterBlockWithData& parameter_block_with_data) {
491   // Get the duration from the parameter definition or the OBU itself as
492   // applicable.
493   const DecodedUleb128 duration = param_definition.param_definition_mode_ == 1
494                                       ? parameter_block_metadata.duration()
495                                       : param_definition.duration_;
496 
497   // Populate the timing information.
498   RETURN_IF_NOT_OK(global_timing_module.GetNextParameterBlockTimestamps(
499       parameter_block_metadata.parameter_id(),
500       parameter_block_metadata.start_timestamp(), duration,
501       parameter_block_with_data.start_timestamp,
502       parameter_block_with_data.end_timestamp));
503 
504   // Populate the OBU.
505   const DecodedUleb128 parameter_id = parameter_block_metadata.parameter_id();
506   parameter_block_with_data.obu = std::make_unique<ParameterBlockObu>(
507       GetHeaderFromMetadata(parameter_block_metadata.obu_header()),
508       parameter_id, param_definition);
509 
510   // Several fields are dependent on `param_definition_mode`.
511   if (param_definition.param_definition_mode_ == 1) {
512     RETURN_IF_NOT_OK(parameter_block_with_data.obu->InitializeSubblocks(
513         parameter_block_metadata.duration(),
514         parameter_block_metadata.constant_subblock_duration(),
515         parameter_block_metadata.num_subblocks()));
516   } else {
517     RETURN_IF_NOT_OK(parameter_block_with_data.obu->InitializeSubblocks());
518   }
519 
520   return absl::OkStatus();
521 }
522 
PopulateSubblocks(const iamf_tools_cli_proto::ParameterBlockObuMetadata & parameter_block_metadata,const bool override_computed_recon_gains,const bool additional_recon_gains_logging,const IdLabeledFrameMap * id_to_labeled_frame,const IdLabeledFrameMap * id_to_labeled_decoded_frame,const ParamDefinitionVariant & param_definition_variant,ParameterBlockWithData & output_parameter_block)523 absl::Status PopulateSubblocks(
524     const iamf_tools_cli_proto::ParameterBlockObuMetadata&
525         parameter_block_metadata,
526     const bool override_computed_recon_gains,
527     const bool additional_recon_gains_logging,
528     const IdLabeledFrameMap* id_to_labeled_frame,
529     const IdLabeledFrameMap* id_to_labeled_decoded_frame,
530     const ParamDefinitionVariant& param_definition_variant,
531     ParameterBlockWithData& output_parameter_block) {
532   auto& parameter_block_obu = *output_parameter_block.obu;
533   const DecodedUleb128 num_subblocks = parameter_block_obu.GetNumSubblocks();
534 
535   // All subblocks will include `subblock_duration` or none will include it.
536   const bool include_subblock_duration =
537       GetParameterDefinitionMode(param_definition_variant) == 1 &&
538       parameter_block_obu.GetConstantSubblockDuration() == 0;
539 
540   if (num_subblocks != parameter_block_metadata.subblocks_size()) {
541     return absl::InvalidArgumentError(
542         absl::StrCat("Expected ", num_subblocks, " subblocks, got ",
543                      parameter_block_metadata.subblocks_size()));
544   }
545   for (int i = 0; i < num_subblocks; ++i) {
546     RETURN_IF_NOT_OK(GenerateParameterBlockSubblock(
547         override_computed_recon_gains, additional_recon_gains_logging,
548         id_to_labeled_frame, id_to_labeled_decoded_frame,
549         param_definition_variant, include_subblock_duration, i,
550         parameter_block_metadata.subblocks(i), parameter_block_obu));
551   }
552 
553   return absl::OkStatus();
554 }
555 
LogParameterBlockObus(const std::list<ParameterBlockWithData> & output_parameter_blocks)556 absl::Status LogParameterBlockObus(
557     const std::list<ParameterBlockWithData>& output_parameter_blocks) {
558   // Log only the first and the last parameter blocks.
559   if (output_parameter_blocks.empty()) {
560     return absl::OkStatus();
561   }
562   std::vector<const ParameterBlockWithData*> to_log = {
563       &output_parameter_blocks.front()};
564   if (output_parameter_blocks.size() > 1) {
565     to_log.push_back(&output_parameter_blocks.back());
566   }
567 
568   for (const auto* parameter_block_with_data : to_log) {
569     parameter_block_with_data->obu->PrintObu();
570     LOG(INFO) << "  // start_timestamp= "
571               << parameter_block_with_data->start_timestamp;
572     LOG(INFO) << "  // end_timestamp= "
573               << parameter_block_with_data->end_timestamp;
574   }
575 
576   return absl::OkStatus();
577 }
578 
579 }  // namespace
580 
// Checks that every registered param definition has a known type that this
// generator supports (demixing, mix gain, or recon gain).
//
// `audio_elements` is not used by this implementation.
absl::Status ParameterBlockGenerator::Initialize(
    const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>&
        audio_elements) {
  for (const auto& [parameter_id, param_definition_variant] :
       param_definition_variants_) {
    const auto definition_type =
        GetParameterDefinitionType(param_definition_variant);
    RETURN_IF_NOT_OK(
        ValidateHasValue(definition_type, "param_definition_type"));
    switch (*definition_type) {
      case ParamDefinition::kParameterDefinitionDemixing:
      case ParamDefinition::kParameterDefinitionMixGain:
      case ParamDefinition::kParameterDefinitionReconGain:
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported parameter type: ", *definition_type));
    }
  }
  return absl::OkStatus();
}
602 
AddMetadata(const iamf_tools_cli_proto::ParameterBlockObuMetadata & parameter_block_metadata)603 absl::Status ParameterBlockGenerator::AddMetadata(
604     const iamf_tools_cli_proto::ParameterBlockObuMetadata&
605         parameter_block_metadata) {
606   const auto& param_definition_iter =
607       param_definition_variants_.find(parameter_block_metadata.parameter_id());
608   if (param_definition_iter == param_definition_variants_.end()) {
609     return absl::InvalidArgumentError(
610         absl::StrCat("No parameter definition found for parameter ID= ",
611                      parameter_block_metadata.parameter_id()));
612   }
613   const auto& param_definition_type = std::visit(
614       [](const auto& param_definition) { return param_definition.GetType(); },
615       param_definition_iter->second);
616   RETURN_IF_NOT_OK(
617       ValidateHasValue(param_definition_type, "`param_definition_type`."));
618   typed_proto_metadata_[*param_definition_type].push_back(
619       parameter_block_metadata);
620 
621   return absl::OkStatus();
622 }
623 
GenerateDemixing(GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)624 absl::Status ParameterBlockGenerator::GenerateDemixing(
625     GlobalTimingModule& global_timing_module,
626     std::list<ParameterBlockWithData>& output_parameter_blocks) {
627   RETURN_IF_NOT_OK(GenerateParameterBlocks(
628       /*id_to_labeled_frame=*/nullptr,
629       /*id_to_labeled_decoded_frame=*/nullptr,
630       typed_proto_metadata_[ParamDefinition::kParameterDefinitionDemixing],
631       global_timing_module, output_parameter_blocks));
632 
633   return absl::OkStatus();
634 }
635 
GenerateMixGain(GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)636 absl::Status ParameterBlockGenerator::GenerateMixGain(
637     GlobalTimingModule& global_timing_module,
638     std::list<ParameterBlockWithData>& output_parameter_blocks) {
639   RETURN_IF_NOT_OK(GenerateParameterBlocks(
640       /*id_to_labeled_frame=*/nullptr,
641       /*id_to_labeled_decoded_frame=*/nullptr,
642       typed_proto_metadata_[ParamDefinition::kParameterDefinitionMixGain],
643       global_timing_module, output_parameter_blocks));
644 
645   return absl::OkStatus();
646 }
647 
648 // TODO(b/306319126): Generate Recon Gain iteratively now that the audio frame
649 //                    decoder decodes iteratively.
GenerateReconGain(const IdLabeledFrameMap & id_to_labeled_frame,const IdLabeledFrameMap & id_to_labeled_decoded_frame,GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)650 absl::Status ParameterBlockGenerator::GenerateReconGain(
651     const IdLabeledFrameMap& id_to_labeled_frame,
652     const IdLabeledFrameMap& id_to_labeled_decoded_frame,
653     GlobalTimingModule& global_timing_module,
654     std::list<ParameterBlockWithData>& output_parameter_blocks) {
655   RETURN_IF_NOT_OK(GenerateParameterBlocks(
656       &id_to_labeled_frame, &id_to_labeled_decoded_frame,
657       typed_proto_metadata_[ParamDefinition::kParameterDefinitionReconGain],
658       global_timing_module, output_parameter_blocks));
659   return absl::OkStatus();
660 }
661 
GenerateParameterBlocks(const IdLabeledFrameMap * id_to_labeled_frame,const IdLabeledFrameMap * id_to_labeled_decoded_frame,std::list<iamf_tools_cli_proto::ParameterBlockObuMetadata> & proto_metadata_list,GlobalTimingModule & global_timing_module,std::list<ParameterBlockWithData> & output_parameter_blocks)662 absl::Status ParameterBlockGenerator::GenerateParameterBlocks(
663     const IdLabeledFrameMap* id_to_labeled_frame,
664     const IdLabeledFrameMap* id_to_labeled_decoded_frame,
665     std::list<iamf_tools_cli_proto::ParameterBlockObuMetadata>&
666         proto_metadata_list,
667     GlobalTimingModule& global_timing_module,
668     std::list<ParameterBlockWithData>& output_parameter_blocks) {
669   for (auto& parameter_block_metadata : proto_metadata_list) {
670     ParameterBlockWithData output_parameter_block;
671     const auto& param_definition_variant =
672         param_definition_variants_.at(parameter_block_metadata.parameter_id());
673     const auto* param_definition_base = std::visit(
674         [](const auto& param_definition) {
675           return static_cast<const ParamDefinition*>(&param_definition);
676         },
677         param_definition_variant);
678     RETURN_IF_NOT_OK(
679         PopulateCommonFields(parameter_block_metadata, *param_definition_base,
680                              global_timing_module, output_parameter_block));
681 
682     RETURN_IF_NOT_OK(PopulateSubblocks(
683         parameter_block_metadata, override_computed_recon_gains_,
684         additional_recon_gains_logging_, id_to_labeled_frame,
685         id_to_labeled_decoded_frame, param_definition_variant,
686         output_parameter_block));
687 
688     // Disable some verbose logging after the first recon gain block is
689     // produced.
690     if (!override_computed_recon_gains_) {
691       additional_recon_gains_logging_ = false;
692     }
693 
694     output_parameter_blocks.push_back(std::move(output_parameter_block));
695   }
696 
697   RETURN_IF_NOT_OK(LogParameterBlockObus(output_parameter_blocks));
698 
699   // Clear the metadata of this frame.
700   proto_metadata_list.clear();
701 
702   return absl::OkStatus();
703 }
704 
705 }  // namespace iamf_tools
706