• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.voiceinteraction;
18 
19 import android.annotation.NonNull;
20 import android.hardware.soundtrigger.SoundTrigger.Keyphrase;
21 import android.hardware.soundtrigger.SoundTrigger.KeyphraseSoundModel;
22 
23 import java.io.PrintWriter;
24 import java.util.Arrays;
25 import java.util.HashMap;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Objects;
29 import java.util.StringJoiner;
30 
31 /**
32  * In memory model enrollment database for testing purposes.
33  * @hide
34  */
35 public class TestModelEnrollmentDatabase implements IEnrolledModelDb {
36 
37     // Record representing the primary key used in the real model database.
38     private static final class EnrollmentKey {
39         private final int mKeyphraseId;
40         private final List<Integer> mUserIds;
41         private final String mLocale;
42 
EnrollmentKey(int keyphraseId, @NonNull List<Integer> userIds, @NonNull String locale)43         EnrollmentKey(int keyphraseId,
44                 @NonNull List<Integer> userIds, @NonNull String locale) {
45             mKeyphraseId = keyphraseId;
46             mUserIds = Objects.requireNonNull(userIds);
47             mLocale = Objects.requireNonNull(locale);
48         }
49 
keyphraseId()50         int keyphraseId() {
51             return mKeyphraseId;
52         }
53 
userIds()54         List<Integer> userIds() {
55             return mUserIds;
56         }
57 
locale()58         String locale() {
59             return mLocale;
60         }
61 
62         @Override
toString()63         public String toString() {
64             StringJoiner sj = new StringJoiner(", ", "{", "}");
65             sj.add("keyphraseId: " + mKeyphraseId);
66             sj.add("userIds: " + mUserIds.toString());
67             sj.add("locale: " + mLocale.toString());
68             return "EnrollmentKey: " + sj.toString();
69         }
70 
71         @Override
hashCode()72         public int hashCode() {
73             final int prime = 31;
74             int res = 1;
75             res = prime * res + mKeyphraseId;
76             res = prime * res + mUserIds.hashCode();
77             res = prime * res + mLocale.hashCode();
78             return res;
79         }
80 
81         @Override
equals(Object other)82         public boolean equals(Object other) {
83             if (this == other) return true;
84             if (other == null) return false;
85             if (!(other instanceof EnrollmentKey)) return false;
86             EnrollmentKey that = (EnrollmentKey) other;
87             if (mKeyphraseId != that.mKeyphraseId) return false;
88             if (!mUserIds.equals(that.mUserIds)) return false;
89             if (!mLocale.equals(that.mLocale)) return false;
90             return true;
91         }
92 
93     }
94 
95     private final Map<EnrollmentKey, KeyphraseSoundModel> mModelMap = new HashMap<>();
96 
97     @Override
updateKeyphraseSoundModel(KeyphraseSoundModel soundModel)98     public boolean updateKeyphraseSoundModel(KeyphraseSoundModel soundModel) {
99         final Keyphrase keyphrase = soundModel.getKeyphrases()[0];
100         mModelMap.put(new EnrollmentKey(keyphrase.getId(),
101                         Arrays.stream(keyphrase.getUsers()).boxed().toList(),
102                         keyphrase.getLocale().toLanguageTag()),
103                     soundModel);
104         return true;
105     }
106 
107     @Override
deleteKeyphraseSoundModel(int keyphraseId, int userHandle, String bcp47Locale)108     public boolean deleteKeyphraseSoundModel(int keyphraseId, int userHandle, String bcp47Locale) {
109         return mModelMap.keySet().removeIf(key -> (key.keyphraseId() == keyphraseId)
110                 && key.locale().equals(bcp47Locale)
111                 && key.userIds().contains(userHandle));
112     }
113 
114     @Override
getKeyphraseSoundModel(int keyphraseId, int userHandle, String bcp47Locale)115     public KeyphraseSoundModel getKeyphraseSoundModel(int keyphraseId, int userHandle,
116             String bcp47Locale) {
117         return mModelMap.entrySet()
118                 .stream()
119                 .filter((entry) -> (entry.getKey().keyphraseId() == keyphraseId)
120                         && entry.getKey().locale().equals(bcp47Locale)
121                         && entry.getKey().userIds().contains(userHandle))
122                 .findFirst()
123                 .map((entry) -> entry.getValue())
124                 .orElse(null);
125     }
126 
127     @Override
getKeyphraseSoundModel(String keyphrase, int userHandle, String bcp47Locale)128     public KeyphraseSoundModel getKeyphraseSoundModel(String keyphrase, int userHandle,
129             String bcp47Locale) {
130         return mModelMap.entrySet()
131                 .stream()
132                 .filter((entry) -> (entry.getValue().getKeyphrases()[0].getText().equals(keyphrase)
133                         && entry.getKey().locale().equals(bcp47Locale)
134                         && entry.getKey().userIds().contains(userHandle)))
135                 .findFirst()
136                 .map((entry) -> entry.getValue())
137                 .orElse(null);
138     }
139 
140 
141     /**
142      * Dumps contents of database for dumpsys
143      */
dump(PrintWriter pw)144     public void dump(PrintWriter pw) {
145         pw.println("Using test enrollment database, with enrolled models:");
146         pw.println(mModelMap);
147     }
148 }
149