1 /*
2 * Copyright 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // #define LOG_NDEBUG 0
18 #define LOG_TAG "audio_utils_MelAggregator"
19
20 #include <audio_utils/MelAggregator.h>
21 #include <audio_utils/power.h>
22 #include <cinttypes>
23 #include <iterator>
24 #include <utils/Log.h>
25
26 namespace android::audio_utils {
27 namespace {
28
/**
 * Smallest CSD amount handled as a record: cached MEL values are only converted
 * once their summed CSD reaches this value.
 */
constexpr float kMinCsdRecordToStore = 0.01f;

/** Exposure that corresponds to 100% CSD, in Pa^2s: 1.6 Pa^2h * 3600 s/h. */
constexpr float kCsdThreshold = 5760.0f;

/** Reference energy, in Pa^2, used for the dB calculation. */
constexpr float kReferenceEnergyPa = 4e-10;
37
38 /**
39 * Checking the intersection of the time intervals of v1 and v2. Each MelRecord v
40 * spawns an interval [t1, t2) if and only if:
41 * v.timestamp == t1 && v.mels.size() == t2 - t1
42 **/
intersectRegion(const MelRecord & v1,const MelRecord & v2)43 std::pair<int64_t, int64_t> intersectRegion(const MelRecord& v1, const MelRecord& v2)
44 {
45 const int64_t maxStart = std::max(v1.timestamp, v2.timestamp);
46 const int64_t v1End = v1.timestamp + v1.mels.size();
47 const int64_t v2End = v2.timestamp + v2.mels.size();
48 const int64_t minEnd = std::min(v1End, v2End);
49 return {maxStart, minEnd};
50 }
51
aggregateMels(const float mel1,const float mel2)52 float aggregateMels(const float mel1, const float mel2) {
53 return audio_utils_power_from_energy(powf(10.f, mel1 / 10.f) + powf(10.f, mel2 / 10.f));
54 }
55
averageMelEnergy(const float mel1,const int64_t duration1,const float mel2,const int64_t duration2)56 float averageMelEnergy(const float mel1,
57 const int64_t duration1,
58 const float mel2,
59 const int64_t duration2) {
60 return audio_utils_power_from_energy((powf(10.f, mel1 / 10.f) * duration1
61 + powf(10.f, mel2 / 10.f) * duration2) / (duration1 + duration2));
62 }
63
melToCsd(float mel)64 float melToCsd(float mel) {
65 float energy = powf(10.f, mel / 10.0f);
66 return kReferenceEnergyPa * energy / kCsdThreshold;
67 }
68
/**
 * Returns a copy of the given record whose CSD value is negated; timestamp,
 * duration and average MEL are taken over unchanged.
 */
CsdRecord createRevertedRecord(const CsdRecord& record) {
    const auto negatedValue = -record.value;
    return CsdRecord{record.timestamp, record.duration, negatedValue, record.averageMel};
}
72
73 } // namespace
74
csdTimeIntervalStored_l()75 int64_t MelAggregator::csdTimeIntervalStored_l()
76 {
77 return mCsdRecords.rbegin()->second.timestamp + mCsdRecords.rbegin()->second.duration
78 - mCsdRecords.begin()->second.timestamp;
79 }
80
addNewestCsdRecord_l(int64_t timestamp,int64_t duration,float csdRecord,float averageMel)81 std::map<int64_t, CsdRecord>::iterator MelAggregator::addNewestCsdRecord_l(int64_t timestamp,
82 int64_t duration,
83 float csdRecord,
84 float averageMel)
85 {
86 ALOGV("%s: add new csd[%" PRId64 ", %" PRId64 "]=%f for MEL avg %f",
87 __func__,
88 timestamp,
89 duration,
90 csdRecord,
91 averageMel);
92
93 mCurrentCsd += csdRecord;
94 return mCsdRecords.emplace_hint(mCsdRecords.end(),
95 timestamp,
96 CsdRecord(timestamp,
97 duration,
98 csdRecord,
99 averageMel));
100 }
101
removeOldCsdRecords_l(std::vector<CsdRecord> & removeRecords)102 void MelAggregator::removeOldCsdRecords_l(std::vector<CsdRecord>& removeRecords) {
103 // Remove older CSD values
104 while (!mCsdRecords.empty() && csdTimeIntervalStored_l() > mCsdWindowSeconds) {
105 mCurrentCsd -= mCsdRecords.begin()->second.value;
106 removeRecords.emplace_back(createRevertedRecord(mCsdRecords.begin()->second));
107 mCsdRecords.erase(mCsdRecords.begin());
108 }
109 }
110
updateCsdRecords_l()111 std::vector<CsdRecord> MelAggregator::updateCsdRecords_l()
112 {
113 std::vector<CsdRecord> newRecords;
114
115 // only update if we are above threshold
116 if (mCurrentMelRecordsCsd < kMinCsdRecordToStore) {
117 removeOldCsdRecords_l(newRecords);
118 return newRecords;
119 }
120
121 float converted = 0.f;
122 float averageMel = 0.f;
123 float csdValue = 0.f;
124 int64_t duration = 0;
125 int64_t timestamp = mMelRecords.begin()->first;
126 for (const auto& storedMel: mMelRecords) {
127 int melsIdx = 0;
128 for (const auto& mel: storedMel.second.mels) {
129 averageMel = averageMelEnergy(averageMel, duration, mel, 1.f);
130 csdValue += melToCsd(mel);
131 ++duration;
132 if (csdValue >= kMinCsdRecordToStore
133 && mCurrentMelRecordsCsd - converted - csdValue >= kMinCsdRecordToStore) {
134 auto it = addNewestCsdRecord_l(timestamp,
135 duration,
136 csdValue,
137 averageMel);
138 newRecords.emplace_back(it->second);
139
140 duration = 0;
141 averageMel = 0.f;
142 converted += csdValue;
143 csdValue = 0.f;
144 timestamp = storedMel.first + melsIdx;
145 }
146 ++ melsIdx;
147 }
148 }
149
150 if(csdValue > 0) {
151 auto it = addNewestCsdRecord_l(timestamp,
152 duration,
153 csdValue,
154 averageMel);
155 newRecords.emplace_back(it->second);
156 }
157
158 removeOldCsdRecords_l(newRecords);
159
160 // reset mel values
161 mCurrentMelRecordsCsd = 0.0f;
162 mMelRecords.clear();
163
164 return newRecords;
165 }
166
aggregateAndAddNewMelRecord(const MelRecord & mel)167 std::vector<CsdRecord> MelAggregator::aggregateAndAddNewMelRecord(const MelRecord& mel)
168 {
169 std::lock_guard _l(mLock);
170 return aggregateAndAddNewMelRecord_l(mel);
171 }
172
aggregateAndAddNewMelRecord_l(const MelRecord & mel)173 std::vector<CsdRecord> MelAggregator::aggregateAndAddNewMelRecord_l(const MelRecord& mel)
174 {
175 for (const auto& m : mel.mels) {
176 mCurrentMelRecordsCsd += melToCsd(m);
177 }
178 ALOGV("%s: current mel values CSD %f", __func__, mCurrentMelRecordsCsd);
179
180 auto mergeIt = mMelRecords.lower_bound(mel.timestamp);
181
182 if (mergeIt != mMelRecords.begin()) {
183 auto prevMergeIt = std::prev(mergeIt);
184 if (prevMergeIt->second.overlapsEnd(mel)) {
185 mergeIt = prevMergeIt;
186 }
187 }
188
189 int64_t newTimestamp = mel.timestamp;
190 std::vector<float> newMels = mel.mels;
191 auto mergeStart = mergeIt;
192 int overlapStart = 0;
193 while(mergeIt != mMelRecords.end()) {
194 const auto& [melRecordStart, melRecord] = *mergeIt;
195 const auto [regionStart, regionEnd] = intersectRegion(melRecord, mel);
196 if (regionStart >= regionEnd) {
197 // no intersection
198 break;
199 }
200
201 if (melRecordStart < regionStart) {
202 newTimestamp = melRecordStart;
203 overlapStart = regionStart - melRecordStart;
204 newMels.insert(newMels.begin(), melRecord.mels.begin(),
205 melRecord.mels.begin() + overlapStart);
206 }
207
208 for (int64_t aggregateTime = regionStart; aggregateTime < regionEnd; ++aggregateTime) {
209 const int offsetStored = aggregateTime - melRecordStart;
210 const int offsetNew = aggregateTime - mel.timestamp;
211 newMels[overlapStart + offsetNew] =
212 aggregateMels(melRecord.mels[offsetStored], mel.mels[offsetNew]);
213 }
214
215 const int64_t mergeEndTime = melRecordStart + melRecord.mels.size();
216 if (mergeEndTime > regionEnd) {
217 newMels.insert(newMels.end(),
218 melRecord.mels.end() - mergeEndTime + regionEnd,
219 melRecord.mels.end());
220 }
221
222 ++mergeIt;
223 }
224
225 auto hint = mergeIt;
226 if (mergeStart != mergeIt) {
227 hint = mMelRecords.erase(mergeStart, mergeIt);
228 }
229
230 mMelRecords.emplace_hint(hint,
231 newTimestamp,
232 MelRecord(mel.portId, newMels, newTimestamp));
233
234 return updateCsdRecords_l();
235 }
236
reset(float newCsd,const std::vector<CsdRecord> & newRecords)237 void MelAggregator::reset(float newCsd, const std::vector<CsdRecord>& newRecords)
238 {
239 std::lock_guard _l(mLock);
240 mCsdRecords.clear();
241 mMelRecords.clear();
242
243 mCurrentCsd = newCsd;
244 for (const auto& record : newRecords) {
245 mCsdRecords.emplace_hint(mCsdRecords.end(), record.timestamp, record);
246 }
247 }
248
getCachedMelRecordsSize() const249 size_t MelAggregator::getCachedMelRecordsSize() const
250 {
251 std::lock_guard _l(mLock);
252 return mMelRecords.size();
253 }
254
foreachCachedMel(const std::function<void (const MelRecord &)> & f) const255 void MelAggregator::foreachCachedMel(const std::function<void(const MelRecord&)>& f) const
256 {
257 std::lock_guard _l(mLock);
258 for (const auto &melRecord : mMelRecords) {
259 f(melRecord.second);
260 }
261 }
262
getCsd()263 float MelAggregator::getCsd() {
264 std::lock_guard _l(mLock);
265 return mCurrentCsd;
266 }
267
getCsdRecordsSize() const268 size_t MelAggregator::getCsdRecordsSize() const {
269 std::lock_guard _l(mLock);
270 return mCsdRecords.size();
271 }
272
foreachCsd(const std::function<void (const CsdRecord &)> & f) const273 void MelAggregator::foreachCsd(const std::function<void(const CsdRecord&)>& f) const
274 {
275 std::lock_guard _l(mLock);
276 for (const auto &csdRecord : mCsdRecords) {
277 f(csdRecord.second);
278 }
279 }
280
281 } // namespace android::audio_utils
282