• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // #define LOG_NDEBUG 0
18 #define LOG_TAG "audio_utils_MelAggregator"
19 
20 #include <audio_utils/MelAggregator.h>
21 #include <audio_utils/power.h>
22 #include <cinttypes>
23 #include <iterator>
24 #include <utils/Log.h>
25 
26 namespace android::audio_utils {
27 namespace {
28 
/**
 * Min CSD value (1.0 corresponds to 100% of the dose, see kCsdThreshold) that
 * the cached MEL values must add up to before they are aggregated to CSD.
 */
constexpr float kMinCsdRecordToStore = 0.01f;

/** Threshold for 100% CSD expressed in Pa^2s: 1.6f(Pa^2h) * 3600.0f(s). */
constexpr float kCsdThreshold = 5760.0f;

/** Reference energy used for dB calculation in Pa^2. */
constexpr float kReferenceEnergyPa = 4e-10f;
37 
38 /**
39  * Checking the intersection of the time intervals of v1 and v2. Each MelRecord v
40  * spawns an interval [t1, t2) if and only if:
41  *    v.timestamp == t1 && v.mels.size() == t2 - t1
42  **/
intersectRegion(const MelRecord & v1,const MelRecord & v2)43 std::pair<int64_t, int64_t> intersectRegion(const MelRecord& v1, const MelRecord& v2)
44 {
45     const int64_t maxStart = std::max(v1.timestamp, v2.timestamp);
46     const int64_t v1End = v1.timestamp + v1.mels.size();
47     const int64_t v2End = v2.timestamp + v2.mels.size();
48     const int64_t minEnd = std::min(v1End, v2End);
49     return {maxStart, minEnd};
50 }
51 
aggregateMels(const float mel1,const float mel2)52 float aggregateMels(const float mel1, const float mel2) {
53     return audio_utils_power_from_energy(powf(10.f, mel1 / 10.f) + powf(10.f, mel2 / 10.f));
54 }
55 
averageMelEnergy(const float mel1,const int64_t duration1,const float mel2,const int64_t duration2)56 float averageMelEnergy(const float mel1,
57                        const int64_t duration1,
58                        const float mel2,
59                        const int64_t duration2) {
60     return audio_utils_power_from_energy((powf(10.f, mel1 / 10.f) * duration1
61         + powf(10.f, mel2 / 10.f) * duration2) / (duration1 + duration2));
62 }
63 
melToCsd(float mel)64 float melToCsd(float mel) {
65     float energy = powf(10.f, mel / 10.0f);
66     return kReferenceEnergyPa * energy / kCsdThreshold;
67 }
68 
/**
 * Builds a copy of the given record whose CSD value is negated; timestamp,
 * duration and average MEL are carried over unchanged.
 */
CsdRecord createRevertedRecord(const CsdRecord& record) {
    const auto negatedValue = -record.value;
    return {record.timestamp, record.duration, negatedValue, record.averageMel};
}
72 
73 }  // namespace
74 
csdTimeIntervalStored_l()75 int64_t MelAggregator::csdTimeIntervalStored_l()
76 {
77     return mCsdRecords.rbegin()->second.timestamp + mCsdRecords.rbegin()->second.duration
78         - mCsdRecords.begin()->second.timestamp;
79 }
80 
addNewestCsdRecord_l(int64_t timestamp,int64_t duration,float csdRecord,float averageMel)81 std::map<int64_t, CsdRecord>::iterator MelAggregator::addNewestCsdRecord_l(int64_t timestamp,
82                                                                            int64_t duration,
83                                                                            float csdRecord,
84                                                                            float averageMel)
85 {
86     ALOGV("%s: add new csd[%" PRId64 ", %" PRId64 "]=%f for MEL avg %f",
87                       __func__,
88                       timestamp,
89                       duration,
90                       csdRecord,
91                       averageMel);
92 
93     mCurrentCsd += csdRecord;
94     return mCsdRecords.emplace_hint(mCsdRecords.end(),
95                                     timestamp,
96                                     CsdRecord(timestamp,
97                                               duration,
98                                               csdRecord,
99                                               averageMel));
100 }
101 
removeOldCsdRecords_l(std::vector<CsdRecord> & removeRecords)102 void MelAggregator::removeOldCsdRecords_l(std::vector<CsdRecord>& removeRecords) {
103     // Remove older CSD values
104     while (!mCsdRecords.empty() && csdTimeIntervalStored_l() > mCsdWindowSeconds) {
105         mCurrentCsd -= mCsdRecords.begin()->second.value;
106         removeRecords.emplace_back(createRevertedRecord(mCsdRecords.begin()->second));
107         mCsdRecords.erase(mCsdRecords.begin());
108     }
109 }
110 
updateCsdRecords_l()111 std::vector<CsdRecord> MelAggregator::updateCsdRecords_l()
112 {
113     std::vector<CsdRecord> newRecords;
114 
115     // only update if we are above threshold
116     if (mCurrentMelRecordsCsd < kMinCsdRecordToStore) {
117         removeOldCsdRecords_l(newRecords);
118         return newRecords;
119     }
120 
121     float converted = 0.f;
122     float averageMel = 0.f;
123     float csdValue = 0.f;
124     int64_t duration = 0;
125     int64_t timestamp = mMelRecords.begin()->first;
126     for (const auto& storedMel: mMelRecords) {
127         int melsIdx = 0;
128         for (const auto& mel: storedMel.second.mels) {
129             averageMel = averageMelEnergy(averageMel, duration, mel, 1.f);
130             csdValue += melToCsd(mel);
131             ++duration;
132             if (csdValue >= kMinCsdRecordToStore
133                 && mCurrentMelRecordsCsd - converted - csdValue >= kMinCsdRecordToStore) {
134                 auto it = addNewestCsdRecord_l(timestamp,
135                                                duration,
136                                                csdValue,
137                                                averageMel);
138                 newRecords.emplace_back(it->second);
139 
140                 duration = 0;
141                 averageMel = 0.f;
142                 converted += csdValue;
143                 csdValue = 0.f;
144                 timestamp = storedMel.first + melsIdx;
145             }
146             ++ melsIdx;
147         }
148     }
149 
150     if(csdValue > 0) {
151         auto it = addNewestCsdRecord_l(timestamp,
152                                        duration,
153                                        csdValue,
154                                        averageMel);
155         newRecords.emplace_back(it->second);
156     }
157 
158     removeOldCsdRecords_l(newRecords);
159 
160     // reset mel values
161     mCurrentMelRecordsCsd = 0.0f;
162     mMelRecords.clear();
163 
164     return newRecords;
165 }
166 
aggregateAndAddNewMelRecord(const MelRecord & mel)167 std::vector<CsdRecord> MelAggregator::aggregateAndAddNewMelRecord(const MelRecord& mel)
168 {
169     std::lock_guard _l(mLock);
170     return aggregateAndAddNewMelRecord_l(mel);
171 }
172 
aggregateAndAddNewMelRecord_l(const MelRecord & mel)173 std::vector<CsdRecord> MelAggregator::aggregateAndAddNewMelRecord_l(const MelRecord& mel)
174 {
175     for (const auto& m : mel.mels) {
176         mCurrentMelRecordsCsd += melToCsd(m);
177     }
178     ALOGV("%s: current mel values CSD %f", __func__, mCurrentMelRecordsCsd);
179 
180     auto mergeIt = mMelRecords.lower_bound(mel.timestamp);
181 
182     if (mergeIt != mMelRecords.begin()) {
183         auto prevMergeIt = std::prev(mergeIt);
184         if (prevMergeIt->second.overlapsEnd(mel)) {
185             mergeIt = prevMergeIt;
186         }
187     }
188 
189     int64_t newTimestamp = mel.timestamp;
190     std::vector<float> newMels = mel.mels;
191     auto mergeStart = mergeIt;
192     int overlapStart = 0;
193     while(mergeIt != mMelRecords.end()) {
194         const auto& [melRecordStart, melRecord] = *mergeIt;
195         const auto [regionStart, regionEnd] = intersectRegion(melRecord, mel);
196         if (regionStart >= regionEnd) {
197             // no intersection
198             break;
199         }
200 
201         if (melRecordStart < regionStart) {
202             newTimestamp = melRecordStart;
203             overlapStart = regionStart - melRecordStart;
204             newMels.insert(newMels.begin(), melRecord.mels.begin(),
205                            melRecord.mels.begin() + overlapStart);
206         }
207 
208         for (int64_t aggregateTime = regionStart; aggregateTime < regionEnd; ++aggregateTime) {
209             const int offsetStored = aggregateTime - melRecordStart;
210             const int offsetNew = aggregateTime - mel.timestamp;
211             newMels[overlapStart + offsetNew] =
212                 aggregateMels(melRecord.mels[offsetStored], mel.mels[offsetNew]);
213         }
214 
215         const int64_t mergeEndTime = melRecordStart + melRecord.mels.size();
216         if (mergeEndTime > regionEnd) {
217             newMels.insert(newMels.end(),
218                            melRecord.mels.end() - mergeEndTime + regionEnd,
219                            melRecord.mels.end());
220         }
221 
222         ++mergeIt;
223     }
224 
225     auto hint = mergeIt;
226     if (mergeStart != mergeIt) {
227         hint = mMelRecords.erase(mergeStart, mergeIt);
228     }
229 
230     mMelRecords.emplace_hint(hint,
231                              newTimestamp,
232                              MelRecord(mel.portId, newMels, newTimestamp));
233 
234     return updateCsdRecords_l();
235 }
236 
reset(float newCsd,const std::vector<CsdRecord> & newRecords)237 void MelAggregator::reset(float newCsd, const std::vector<CsdRecord>& newRecords)
238 {
239     std::lock_guard _l(mLock);
240     mCsdRecords.clear();
241     mMelRecords.clear();
242 
243     mCurrentCsd = newCsd;
244     for (const auto& record : newRecords) {
245         mCsdRecords.emplace_hint(mCsdRecords.end(), record.timestamp, record);
246     }
247 }
248 
getCachedMelRecordsSize() const249 size_t MelAggregator::getCachedMelRecordsSize() const
250 {
251     std::lock_guard _l(mLock);
252     return mMelRecords.size();
253 }
254 
foreachCachedMel(const std::function<void (const MelRecord &)> & f) const255 void MelAggregator::foreachCachedMel(const std::function<void(const MelRecord&)>& f) const
256 {
257      std::lock_guard _l(mLock);
258      for (const auto &melRecord : mMelRecords) {
259          f(melRecord.second);
260      }
261 }
262 
getCsd()263 float MelAggregator::getCsd() {
264     std::lock_guard _l(mLock);
265     return mCurrentCsd;
266 }
267 
getCsdRecordsSize() const268 size_t MelAggregator::getCsdRecordsSize() const {
269     std::lock_guard _l(mLock);
270     return mCsdRecords.size();
271 }
272 
foreachCsd(const std::function<void (const CsdRecord &)> & f) const273 void MelAggregator::foreachCsd(const std::function<void(const CsdRecord&)>& f) const
274 {
275      std::lock_guard _l(mLock);
276      for (const auto &csdRecord : mCsdRecords) {
277          f(csdRecord.second);
278      }
279 }
280 
281 }  // namespace android::audio_utils
282