• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.android.textclassifier;
18 
19 import android.content.res.AssetFileDescriptor;
20 import java.util.concurrent.atomic.AtomicBoolean;
21 import javax.annotation.Nullable;
22 
23 /**
24  * Java wrapper for ActionsSuggestions native library interface. This library is used to suggest
25  * actions and replies in a given conversation.
26  *
27  * @hide
28  */
29 public final class ActionsSuggestionsModel implements AutoCloseable {
30   private final AtomicBoolean isClosed = new AtomicBoolean(false);
31 
32   static {
33     System.loadLibrary("textclassifier");
34   }
35 
36   private long actionsModelPtr;
37 
38   /**
39    * Creates a new instance of Actions predictor, using the provided model image, given as a file
40    * descriptor.
41    */
ActionsSuggestionsModel(int fileDescriptor, @Nullable byte[] serializedPreconditions)42   public ActionsSuggestionsModel(int fileDescriptor, @Nullable byte[] serializedPreconditions) {
43     actionsModelPtr = nativeNewActionsModel(fileDescriptor, serializedPreconditions);
44     if (actionsModelPtr == 0L) {
45       throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
46     }
47   }
48 
ActionsSuggestionsModel(int fileDescriptor)49   public ActionsSuggestionsModel(int fileDescriptor) {
50     this(fileDescriptor, /* serializedPreconditions= */ null);
51   }
52 
53   /**
54    * Creates a new instance of Actions predictor, using the provided model image, given as a file
55    * path.
56    */
ActionsSuggestionsModel(String path, @Nullable byte[] serializedPreconditions)57   public ActionsSuggestionsModel(String path, @Nullable byte[] serializedPreconditions) {
58     actionsModelPtr = nativeNewActionsModelFromPath(path, serializedPreconditions);
59     if (actionsModelPtr == 0L) {
60       throw new IllegalArgumentException("Couldn't initialize actions model from given file.");
61     }
62   }
63 
ActionsSuggestionsModel(String path)64   public ActionsSuggestionsModel(String path) {
65     this(path, /* serializedPreconditions= */ null);
66   }
67 
68   /**
69    * Creates a new instance of Actions predictor, using the provided model image, given as an {@link
70    * AssetFileDescriptor}).
71    */
ActionsSuggestionsModel( AssetFileDescriptor assetFileDescriptor, @Nullable byte[] serializedPreconditions)72   public ActionsSuggestionsModel(
73       AssetFileDescriptor assetFileDescriptor, @Nullable byte[] serializedPreconditions) {
74     actionsModelPtr =
75         nativeNewActionsModelWithOffset(
76             assetFileDescriptor.getParcelFileDescriptor().getFd(),
77             assetFileDescriptor.getStartOffset(),
78             assetFileDescriptor.getLength(),
79             serializedPreconditions);
80     if (actionsModelPtr == 0L) {
81       throw new IllegalArgumentException("Couldn't initialize actions model from file descriptor.");
82     }
83   }
84 
ActionsSuggestionsModel(AssetFileDescriptor assetFileDescriptor)85   public ActionsSuggestionsModel(AssetFileDescriptor assetFileDescriptor) {
86     this(assetFileDescriptor, /* serializedPreconditions= */ null);
87   }
88 
89   /** Suggests actions / replies to the given conversation. */
suggestActions( Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator)90   public ActionSuggestions suggestActions(
91       Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
92     return nativeSuggestActions(
93         actionsModelPtr,
94         conversation,
95         options,
96         (annotator != null ? annotator.getNativeAnnotatorPointer() : 0),
97         /* appContext= */ null,
98         /* deviceLocales= */ null,
99         /* generateAndroidIntents= */ false);
100   }
101 
suggestActionsWithIntents( Conversation conversation, ActionSuggestionOptions options, Object appContext, String deviceLocales, AnnotatorModel annotator)102   public ActionSuggestions suggestActionsWithIntents(
103       Conversation conversation,
104       ActionSuggestionOptions options,
105       Object appContext,
106       String deviceLocales,
107       AnnotatorModel annotator) {
108     return nativeSuggestActions(
109         actionsModelPtr,
110         conversation,
111         options,
112         (annotator != null ? annotator.getNativeAnnotatorPointer() : 0),
113         appContext,
114         deviceLocales,
115         /* generateAndroidIntents= */ true);
116   }
117 
118   /** Frees up the allocated memory. */
119   @Override
close()120   public void close() {
121     if (isClosed.compareAndSet(false, true)) {
122       nativeCloseActionsModel(actionsModelPtr);
123       actionsModelPtr = 0L;
124     }
125   }
126 
127   @Override
finalize()128   protected void finalize() throws Throwable {
129     try {
130       close();
131     } finally {
132       super.finalize();
133     }
134   }
135 
136   /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
getLocales(int fd)137   public static String getLocales(int fd) {
138     return nativeGetLocales(fd);
139   }
140 
141   /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
getLocales(AssetFileDescriptor assetFileDescriptor)142   public static String getLocales(AssetFileDescriptor assetFileDescriptor) {
143     return nativeGetLocalesWithOffset(
144         assetFileDescriptor.getParcelFileDescriptor().getFd(),
145         assetFileDescriptor.getStartOffset(),
146         assetFileDescriptor.getLength());
147   }
148 
149   /** Returns the version of the model. */
getVersion(int fd)150   public static int getVersion(int fd) {
151     return nativeGetVersion(fd);
152   }
153 
154   /** Returns the version of the model. */
getVersion(AssetFileDescriptor assetFileDescriptor)155   public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
156     return nativeGetVersionWithOffset(
157         assetFileDescriptor.getParcelFileDescriptor().getFd(),
158         assetFileDescriptor.getStartOffset(),
159         assetFileDescriptor.getLength());
160   }
161 
162   /** Returns the name of the model. */
getName(int fd)163   public static String getName(int fd) {
164     return nativeGetName(fd);
165   }
166 
167   /** Returns the name of the model. */
getName(AssetFileDescriptor assetFileDescriptor)168   public static String getName(AssetFileDescriptor assetFileDescriptor) {
169     return nativeGetNameWithOffset(
170         assetFileDescriptor.getParcelFileDescriptor().getFd(),
171         assetFileDescriptor.getStartOffset(),
172         assetFileDescriptor.getLength());
173   }
174 
175   /** Initializes conversation intent detection, passing the given serialized config to it. */
initializeConversationIntentDetection(byte[] serializedConfig)176   public void initializeConversationIntentDetection(byte[] serializedConfig) {
177     if (!nativeInitializeConversationIntentDetection(actionsModelPtr, serializedConfig)) {
178       throw new IllegalArgumentException("Couldn't initialize conversation intent detection");
179     }
180   }
181 
182   /** Represents a list of suggested actions of a given conversation. */
183   public static final class ActionSuggestions {
184     /** A list of suggested actionsm sorted by score descendingly. */
185     public final ActionSuggestion[] actionSuggestions;
186     /** Whether the input conversation is considered as sensitive. */
187     public final boolean isSensitive;
188 
ActionSuggestions(ActionSuggestion[] actionSuggestions, boolean isSensitive)189     public ActionSuggestions(ActionSuggestion[] actionSuggestions, boolean isSensitive) {
190       this.actionSuggestions = actionSuggestions;
191       this.isSensitive = isSensitive;
192     }
193   }
194 
195   /** Action suggestion that contains a response text and the type of the response. */
196   public static final class ActionSuggestion {
197     @Nullable private final String responseText;
198     private final String actionType;
199     private final float score;
200     @Nullable private final NamedVariant[] entityData;
201     @Nullable private final byte[] serializedEntityData;
202     @Nullable private final RemoteActionTemplate[] remoteActionTemplates;
203     @Nullable private final Slot[] slots;
204 
ActionSuggestion( @ullable String responseText, String actionType, float score, @Nullable NamedVariant[] entityData, @Nullable byte[] serializedEntityData, @Nullable RemoteActionTemplate[] remoteActionTemplates, @Nullable Slot[] slots)205     public ActionSuggestion(
206         @Nullable String responseText,
207         String actionType,
208         float score,
209         @Nullable NamedVariant[] entityData,
210         @Nullable byte[] serializedEntityData,
211         @Nullable RemoteActionTemplate[] remoteActionTemplates,
212         @Nullable Slot[] slots) {
213       this.responseText = responseText;
214       this.actionType = actionType;
215       this.score = score;
216       this.entityData = entityData;
217       this.serializedEntityData = serializedEntityData;
218       this.remoteActionTemplates = remoteActionTemplates;
219       this.slots = slots;
220     }
221 
222     @Nullable
getResponseText()223     public String getResponseText() {
224       return responseText;
225     }
226 
getActionType()227     public String getActionType() {
228       return actionType;
229     }
230 
231     /** Confidence score between 0 and 1 */
getScore()232     public float getScore() {
233       return score;
234     }
235 
236     @Nullable
getEntityData()237     public NamedVariant[] getEntityData() {
238       return entityData;
239     }
240 
241     @Nullable
getSerializedEntityData()242     public byte[] getSerializedEntityData() {
243       return serializedEntityData;
244     }
245 
246     @Nullable
getRemoteActionTemplates()247     public RemoteActionTemplate[] getRemoteActionTemplates() {
248       return remoteActionTemplates;
249     }
250 
251     @Nullable
getSlots()252     public Slot[] getSlots() {
253       return slots;
254     }
255   }
256 
257   /** Represents a single message in the conversation. */
258   public static final class ConversationMessage {
259     private final int userId;
260     @Nullable private final String text;
261     private final long referenceTimeMsUtc;
262     @Nullable private final String referenceTimezone;
263     @Nullable private final String detectedTextLanguageTags;
264 
ConversationMessage( int userId, @Nullable String text, long referenceTimeMsUtc, @Nullable String referenceTimezone, @Nullable String detectedTextLanguageTags)265     public ConversationMessage(
266         int userId,
267         @Nullable String text,
268         long referenceTimeMsUtc,
269         @Nullable String referenceTimezone,
270         @Nullable String detectedTextLanguageTags) {
271       this.userId = userId;
272       this.text = text;
273       this.referenceTimeMsUtc = referenceTimeMsUtc;
274       this.referenceTimezone = referenceTimezone;
275       this.detectedTextLanguageTags = detectedTextLanguageTags;
276     }
277 
278     /** The identifier of the sender */
getUserId()279     public int getUserId() {
280       return userId;
281     }
282 
283     @Nullable
getText()284     public String getText() {
285       return text;
286     }
287 
288     /**
289      * Return the reference time of the message, for example, it could be compose time or send time.
290      * {@code 0} means unspecified.
291      */
getReferenceTimeMsUtc()292     public long getReferenceTimeMsUtc() {
293       return referenceTimeMsUtc;
294     }
295 
296     @Nullable
getReferenceTimezone()297     public String getReferenceTimezone() {
298       return referenceTimezone;
299     }
300 
301     /** Returns a comma separated list of BCP 47 language tags. */
302     @Nullable
getDetectedTextLanguageTags()303     public String getDetectedTextLanguageTags() {
304       return detectedTextLanguageTags;
305     }
306   }
307 
308   /** Represents conversation between multiple users. */
309   public static final class Conversation {
310     public final ConversationMessage[] conversationMessages;
311 
Conversation(ConversationMessage[] conversationMessages)312     public Conversation(ConversationMessage[] conversationMessages) {
313       this.conversationMessages = conversationMessages;
314     }
315 
getConversationMessages()316     public ConversationMessage[] getConversationMessages() {
317       return conversationMessages;
318     }
319   }
320 
321   /** Represents options for the SuggestActions call. */
322   public static final class ActionSuggestionOptions {
ActionSuggestionOptions()323     public ActionSuggestionOptions() {}
324   }
325 
326   /** Represents a slot for an {@link ActionSuggestion}. */
327   public static final class Slot {
328 
329     public final String type;
330     public final int messageIndex;
331     public final int startIndex;
332     public final int endIndex;
333     public final float confidenceScore;
334 
Slot( String type, int messageIndex, int startIndex, int endIndex, float confidenceScore)335     public Slot(
336         String type, int messageIndex, int startIndex, int endIndex, float confidenceScore) {
337       this.type = type;
338       this.messageIndex = messageIndex;
339       this.startIndex = startIndex;
340       this.endIndex = endIndex;
341       this.confidenceScore = confidenceScore;
342     }
343   }
344 
345   /**
346    * Retrieves the pointer to the native object. Note: Need to keep the {@code
347    * ActionsSuggestionsModel} alive as long as the pointer is used.
348    */
getNativeModelPointer()349   long getNativeModelPointer() {
350     return nativeGetNativeModelPtr(actionsModelPtr);
351   }
352 
nativeNewActionsModel(int fd, byte[] serializedPreconditions)353   private static native long nativeNewActionsModel(int fd, byte[] serializedPreconditions);
354 
nativeNewActionsModelFromPath( String path, byte[] preconditionsOverwrite)355   private static native long nativeNewActionsModelFromPath(
356       String path, byte[] preconditionsOverwrite);
357 
nativeNewActionsModelWithOffset( int fd, long offset, long size, byte[] preconditionsOverwrite)358   private static native long nativeNewActionsModelWithOffset(
359       int fd, long offset, long size, byte[] preconditionsOverwrite);
360 
nativeInitializeConversationIntentDetection( long actionsModelPtr, byte[] serializedConfig)361   private native boolean nativeInitializeConversationIntentDetection(
362       long actionsModelPtr, byte[] serializedConfig);
363 
nativeGetLocales(int fd)364   private static native String nativeGetLocales(int fd);
365 
nativeGetLocalesWithOffset(int fd, long offset, long size)366   private static native String nativeGetLocalesWithOffset(int fd, long offset, long size);
367 
nativeGetVersion(int fd)368   private static native int nativeGetVersion(int fd);
369 
nativeGetVersionWithOffset(int fd, long offset, long size)370   private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
371 
nativeGetName(int fd)372   private static native String nativeGetName(int fd);
373 
nativeGetNameWithOffset(int fd, long offset, long size)374   private static native String nativeGetNameWithOffset(int fd, long offset, long size);
375 
nativeSuggestActions( long context, Conversation conversation, ActionSuggestionOptions options, long annotatorPtr, Object appContext, String deviceLocales, boolean generateAndroidIntents)376   private native ActionSuggestions nativeSuggestActions(
377       long context,
378       Conversation conversation,
379       ActionSuggestionOptions options,
380       long annotatorPtr,
381       Object appContext,
382       String deviceLocales,
383       boolean generateAndroidIntents);
384 
nativeCloseActionsModel(long ptr)385   private native void nativeCloseActionsModel(long ptr);
386 
nativeGetNativeModelPtr(long context)387   private native long nativeGetNativeModelPtr(long context);
388 }
389