1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.android.textclassifier; 18 19 import java.util.concurrent.atomic.AtomicBoolean; 20 21 /** 22 * Java wrapper for ActionsSuggestions native library interface. This library is used to suggest 23 * actions and replies in a given conversation. 24 * 25 * @hide 26 */ 27 public final class ActionsSuggestionsModel implements AutoCloseable { 28 private final AtomicBoolean isClosed = new AtomicBoolean(false); 29 30 static { 31 System.loadLibrary("textclassifier"); 32 } 33 34 private long actionsModelPtr; 35 36 /** 37 * Creates a new instance of Actions predictor, using the provided model image, given as a file 38 * descriptor. 39 */ ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions)40 public ActionsSuggestionsModel(int fileDescriptor, byte[] serializedPreconditions) { 41 actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions); 42 if (actionsModelPtr == 0L) { 43 throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor."); 44 } 45 } 46 ActionsSuggestionsModel(int fileDescriptor)47 public ActionsSuggestionsModel(int fileDescriptor) { 48 this(fileDescriptor, /* serializedPreconditions= */ null); 49 } 50 51 /** 52 * Creates a new instance of Actions predictor, using the provided model image, given as a file 53 * path. 54 */ ActionsSuggestionsModel(String path, byte[] serializedPreconditions)55 public ActionsSuggestionsModel(String path, byte[] serializedPreconditions) { 56 actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions); 57 if (actionsModelPtr == 0L) { 58 throw new IllegalArgumentException("Couldn't initialize actions model from given file."); 59 } 60 } 61 ActionsSuggestionsModel(String path)62 public ActionsSuggestionsModel(String path) { 63 this(path, /* serializedPreconditions= */ null); 64 } 65 66 /** Suggests actions / replies to the given conversation. */ suggestActions( Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator)67 public ActionSuggestion[] suggestActions( 68 Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) { 69 return nativeSuggestActions( 70 actionsModelPtr, 71 conversation, 72 options, 73 (annotator != null ? annotator.getNativeAnnotator() : 0), 74 /* appContext= */ null, 75 /* deviceLocales= */ null, 76 /* generateAndroidIntents= */ false); 77 } 78 suggestActionsWithIntents( Conversation conversation, ActionSuggestionOptions options, Object appContext, String deviceLocales, AnnotatorModel annotator)79 public ActionSuggestion[] suggestActionsWithIntents( 80 Conversation conversation, 81 ActionSuggestionOptions options, 82 Object appContext, 83 String deviceLocales, 84 AnnotatorModel annotator) { 85 return nativeSuggestActions( 86 actionsModelPtr, 87 conversation, 88 options, 89 (annotator != null ? annotator.getNativeAnnotator() : 0), 90 appContext, 91 deviceLocales, 92 /* generateAndroidIntents= */ true); 93 } 94 95 /** Frees up the allocated memory. */ 96 @Override close()97 public void close() { 98 if (isClosed.compareAndSet(false, true)) { 99 nativeCloseActionsModel(actionsModelPtr); 100 actionsModelPtr = 0L; 101 } 102 } 103 104 @Override finalize()105 protected void finalize() throws Throwable { 106 try { 107 close(); 108 } finally { 109 super.finalize(); 110 } 111 } 112 113 /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */ getLocales(int fd)114 public static String getLocales(int fd) { 115 return nativeGetLocales(fd); 116 } 117 118 /** Returns the version of the model. */ getVersion(int fd)119 public static int getVersion(int fd) { 120 return nativeGetVersion(fd); 121 } 122 123 /** Returns the name of the model. */ getName(int fd)124 public static String getName(int fd) { 125 return nativeGetName(fd); 126 } 127 128 /** Action suggestion that contains a response text and the type of the response. */ 129 public static final class ActionSuggestion { 130 private final String responseText; 131 private final String actionType; 132 private final float score; 133 private final NamedVariant[] entityData; 134 private final byte[] serializedEntityData; 135 private final RemoteActionTemplate[] remoteActionTemplates; 136 ActionSuggestion( String responseText, String actionType, float score, NamedVariant[] entityData, byte[] serializedEntityData, RemoteActionTemplate[] remoteActionTemplates)137 public ActionSuggestion( 138 String responseText, 139 String actionType, 140 float score, 141 NamedVariant[] entityData, 142 byte[] serializedEntityData, 143 RemoteActionTemplate[] remoteActionTemplates) { 144 this.responseText = responseText; 145 this.actionType = actionType; 146 this.score = score; 147 this.entityData = entityData; 148 this.serializedEntityData = serializedEntityData; 149 this.remoteActionTemplates = remoteActionTemplates; 150 } 151 getResponseText()152 public String getResponseText() { 153 return responseText; 154 } 155 getActionType()156 public String getActionType() { 157 return actionType; 158 } 159 160 /** Confidence score between 0 and 1 */ getScore()161 public float getScore() { 162 return score; 163 } 164 getEntityData()165 public NamedVariant[] getEntityData() { 166 return entityData; 167 } 168 getSerializedEntityData()169 public byte[] getSerializedEntityData() { 170 return serializedEntityData; 171 } 172 getRemoteActionTemplates()173 public RemoteActionTemplate[] getRemoteActionTemplates() { 174 return remoteActionTemplates; 175 } 176 } 177 178 /** Represents a single message in the conversation. */ 179 public static final class ConversationMessage { 180 private final int userId; 181 private final String text; 182 private final long referenceTimeMsUtc; 183 private final String referenceTimezone; 184 private final String detectedTextLanguageTags; 185 ConversationMessage( int userId, String text, long referenceTimeMsUtc, String referenceTimezone, String detectedTextLanguageTags)186 public ConversationMessage( 187 int userId, 188 String text, 189 long referenceTimeMsUtc, 190 String referenceTimezone, 191 String detectedTextLanguageTags) { 192 this.userId = userId; 193 this.text = text; 194 this.referenceTimeMsUtc = referenceTimeMsUtc; 195 this.referenceTimezone = referenceTimezone; 196 this.detectedTextLanguageTags = detectedTextLanguageTags; 197 } 198 199 /** The identifier of the sender */ getUserId()200 public int getUserId() { 201 return userId; 202 } 203 getText()204 public String getText() { 205 return text; 206 } 207 208 /** 209 * Return the reference time of the message, for example, it could be compose time or send time. 210 * {@code 0} means unspecified. 211 */ getReferenceTimeMsUtc()212 public long getReferenceTimeMsUtc() { 213 return referenceTimeMsUtc; 214 } 215 getReferenceTimezone()216 public String getReferenceTimezone() { 217 return referenceTimezone; 218 } 219 220 /** Returns a comma separated list of BCP 47 language tags. */ getDetectedTextLanguageTags()221 public String getDetectedTextLanguageTags() { 222 return detectedTextLanguageTags; 223 } 224 } 225 226 /** Represents conversation between multiple users. */ 227 public static final class Conversation { 228 public final ConversationMessage[] conversationMessages; 229 Conversation(ConversationMessage[] conversationMessages)230 public Conversation(ConversationMessage[] conversationMessages) { 231 this.conversationMessages = conversationMessages; 232 } 233 getConversationMessages()234 public ConversationMessage[] getConversationMessages() { 235 return conversationMessages; 236 } 237 } 238 239 /** Represents options for the SuggestActions call. */ 240 public static final class ActionSuggestionOptions { ActionSuggestionOptions()241 public ActionSuggestionOptions() {} 242 } 243 nativeNewActionsModel(int fd, byte[] serializedPreconditions)244 private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions); 245 nativeNewActionsModelFromPath( String path, byte[] preconditionsOverwrite)246 private static native long nativeNewActionsModelFromPath( 247 String path, byte[] preconditionsOverwrite); 248 nativeGetLocales(int fd)249 private static native String nativeGetLocales(int fd); 250 nativeGetVersion(int fd)251 private static native int nativeGetVersion(int fd); 252 nativeGetName(int fd)253 private static native String nativeGetName(int fd); 254 nativeSuggestActions( long context, Conversation conversation, ActionSuggestionOptions options, long annotatorPtr, Object appContext, String deviceLocales, boolean generateAndroidIntents)255 private native ActionSuggestion[] nativeSuggestActions( 256 long context, 257 Conversation conversation, 258 ActionSuggestionOptions options, 259 long annotatorPtr, 260 Object appContext, 261 String deviceLocales, 262 boolean generateAndroidIntents); 263 nativeCloseActionsModel(long ptr)264 private native void nativeCloseActionsModel(long ptr); 265 } 266