1 /** 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.server.voiceinteraction; 18 19 import android.annotation.NonNull; 20 import android.hardware.soundtrigger.SoundTrigger.Keyphrase; 21 import android.hardware.soundtrigger.SoundTrigger.KeyphraseSoundModel; 22 23 import java.io.PrintWriter; 24 import java.util.Arrays; 25 import java.util.HashMap; 26 import java.util.List; 27 import java.util.Map; 28 import java.util.Objects; 29 import java.util.StringJoiner; 30 31 /** 32 * In memory model enrollment database for testing purposes. 33 * @hide 34 */ 35 public class TestModelEnrollmentDatabase implements IEnrolledModelDb { 36 37 // Record representing the primary key used in the real model database. 38 private static final class EnrollmentKey { 39 private final int mKeyphraseId; 40 private final List<Integer> mUserIds; 41 private final String mLocale; 42 EnrollmentKey(int keyphraseId, @NonNull List<Integer> userIds, @NonNull String locale)43 EnrollmentKey(int keyphraseId, 44 @NonNull List<Integer> userIds, @NonNull String locale) { 45 mKeyphraseId = keyphraseId; 46 mUserIds = Objects.requireNonNull(userIds); 47 mLocale = Objects.requireNonNull(locale); 48 } 49 keyphraseId()50 int keyphraseId() { 51 return mKeyphraseId; 52 } 53 userIds()54 List<Integer> userIds() { 55 return mUserIds; 56 } 57 locale()58 String locale() { 59 return mLocale; 60 } 61 62 @Override toString()63 public String toString() { 64 StringJoiner sj = new StringJoiner(", ", "{", "}"); 65 sj.add("keyphraseId: " + mKeyphraseId); 66 sj.add("userIds: " + mUserIds.toString()); 67 sj.add("locale: " + mLocale.toString()); 68 return "EnrollmentKey: " + sj.toString(); 69 } 70 71 @Override hashCode()72 public int hashCode() { 73 final int prime = 31; 74 int res = 1; 75 res = prime * res + mKeyphraseId; 76 res = prime * res + mUserIds.hashCode(); 77 res = prime * res + mLocale.hashCode(); 78 return res; 79 } 80 81 @Override equals(Object other)82 public boolean equals(Object other) { 83 if (this == other) return true; 84 if (other == null) return false; 85 if (!(other instanceof EnrollmentKey)) return false; 86 EnrollmentKey that = (EnrollmentKey) other; 87 if (mKeyphraseId != that.mKeyphraseId) return false; 88 if (!mUserIds.equals(that.mUserIds)) return false; 89 if (!mLocale.equals(that.mLocale)) return false; 90 return true; 91 } 92 93 } 94 95 private final Map<EnrollmentKey, KeyphraseSoundModel> mModelMap = new HashMap<>(); 96 97 @Override updateKeyphraseSoundModel(KeyphraseSoundModel soundModel)98 public boolean updateKeyphraseSoundModel(KeyphraseSoundModel soundModel) { 99 final Keyphrase keyphrase = soundModel.getKeyphrases()[0]; 100 mModelMap.put(new EnrollmentKey(keyphrase.getId(), 101 Arrays.stream(keyphrase.getUsers()).boxed().toList(), 102 keyphrase.getLocale().toLanguageTag()), 103 soundModel); 104 return true; 105 } 106 107 @Override deleteKeyphraseSoundModel(int keyphraseId, int userHandle, String bcp47Locale)108 public boolean deleteKeyphraseSoundModel(int keyphraseId, int userHandle, String bcp47Locale) { 109 return mModelMap.keySet().removeIf(key -> (key.keyphraseId() == keyphraseId) 110 && key.locale().equals(bcp47Locale) 111 && key.userIds().contains(userHandle)); 112 } 113 114 @Override getKeyphraseSoundModel(int keyphraseId, int userHandle, String bcp47Locale)115 public KeyphraseSoundModel getKeyphraseSoundModel(int keyphraseId, int userHandle, 116 String bcp47Locale) { 117 return mModelMap.entrySet() 118 .stream() 119 .filter((entry) -> (entry.getKey().keyphraseId() == keyphraseId) 120 && entry.getKey().locale().equals(bcp47Locale) 121 && entry.getKey().userIds().contains(userHandle)) 122 .findFirst() 123 .map((entry) -> entry.getValue()) 124 .orElse(null); 125 } 126 127 @Override getKeyphraseSoundModel(String keyphrase, int userHandle, String bcp47Locale)128 public KeyphraseSoundModel getKeyphraseSoundModel(String keyphrase, int userHandle, 129 String bcp47Locale) { 130 return mModelMap.entrySet() 131 .stream() 132 .filter((entry) -> (entry.getValue().getKeyphrases()[0].getText().equals(keyphrase) 133 && entry.getKey().locale().equals(bcp47Locale) 134 && entry.getKey().userIds().contains(userHandle))) 135 .findFirst() 136 .map((entry) -> entry.getValue()) 137 .orElse(null); 138 } 139 140 141 /** 142 * Dumps contents of database for dumpsys 143 */ dump(PrintWriter pw)144 public void dump(PrintWriter pw) { 145 pw.println("Using test enrollment database, with enrolled models:"); 146 pw.println(mModelMap); 147 } 148 } 149