1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "UnwantedInteractionBlocker"
18 #include "UnwantedInteractionBlocker.h"
19 
20 #include <android-base/stringprintf.h>
21 #include <input/PrintTools.h>
22 #include <inttypes.h>
23 #include <linux/input-event-codes.h>
24 #include <linux/input.h>
25 #include <server_configurable_flags/get_flags.h>
26 
27 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h"
28 #include "ui/events/ozone/evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_model.h"
29 
30 using android::base::StringPrintf;
31 
32 /**
33  * This type is declared here to ensure consistency between the instantiated type (used in the
34  * constructor via std::make_unique) and the cast-to type (used in PalmRejector::dump() with
35  * static_cast). Due to the lack of rtti support, dynamic_cast is not available, so this can't be
36  * checked at runtime to avoid undefined behaviour.
37  */
38 using PalmFilterImplementation = ::ui::NeuralStylusPalmDetectionFilter;
39 
40 namespace android {
41 
42 /**
43  * Log detailed debug messages about each inbound motion event notification to the blocker.
44  * Enable this via "adb shell setprop log.tag.UnwantedInteractionBlockerInboundMotion DEBUG"
45  * (requires restart)
46  */
47 const bool DEBUG_INBOUND_MOTION =
48         __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "InboundMotion", ANDROID_LOG_INFO);
49 
50 /**
51  * Log detailed debug messages about each outbound motion event processed by the blocker.
52  * Enable this via "adb shell setprop log.tag.UnwantedInteractionBlockerOutboundMotion DEBUG"
53  * (requires restart)
54  */
55 const bool DEBUG_OUTBOUND_MOTION =
56         __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "OutboundMotion", ANDROID_LOG_INFO);
57 
58 /**
59  * Log the data sent to the model and received back from the model.
60  * Enable this via "adb shell setprop log.tag.UnwantedInteractionBlockerModel DEBUG"
61  * (requires restart)
62  */
63 const bool DEBUG_MODEL =
64         __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "Model", ANDROID_LOG_INFO);
65 
66 // Category (=namespace) name for the input settings that are applied at boot time
67 static const char* INPUT_NATIVE_BOOT = "input_native_boot";
68 /**
69  * Feature flag name. This flag determines whether palm rejection is enabled. To enable, specify
70  * 'true' (not case sensitive) or '1'. To disable, specify any other value.
71  */
72 static const char* PALM_REJECTION_ENABLED = "palm_rejection_enabled";
73 
/**
 * Return a copy of |s| in which every character has been passed through std::tolower.
 * Each char is widened through unsigned char first, so chars with the high bit set are safe.
 */
static std::string toLower(std::string s) {
    for (char& ch : s) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    return s;
}
78 
isFromTouchscreen(int32_t source)79 static bool isFromTouchscreen(int32_t source) {
80     return isFromSource(source, AINPUT_SOURCE_TOUCHSCREEN);
81 }
82 
toChromeTimestamp(nsecs_t eventTime)83 static ::base::TimeTicks toChromeTimestamp(nsecs_t eventTime) {
84     return ::base::TimeTicks::UnixEpoch() + ::base::TimeDelta::FromNanosecondsD(eventTime);
85 }
86 
87 /**
88  * Return true if palm rejection is enabled via the server configurable flags. Return false
89  * otherwise.
90  */
isPalmRejectionEnabled()91 static bool isPalmRejectionEnabled() {
92     std::string value = toLower(
93             server_configurable_flags::GetServerConfigurableFlag(INPUT_NATIVE_BOOT,
94                                                                  PALM_REJECTION_ENABLED, "0"));
95     if (value == "1") {
96         return true;
97     }
98     return false;
99 }
100 
getLinuxToolCode(int toolType)101 static int getLinuxToolCode(int toolType) {
102     if (toolType == AMOTION_EVENT_TOOL_TYPE_STYLUS) {
103         return BTN_TOOL_PEN;
104     }
105     if (toolType == AMOTION_EVENT_TOOL_TYPE_FINGER) {
106         return BTN_TOOL_FINGER;
107     }
108     ALOGW("Got tool type %" PRId32 ", converting to BTN_TOOL_FINGER", toolType);
109     return BTN_TOOL_FINGER;
110 }
111 
getActionUpForPointerId(const NotifyMotionArgs & args,int32_t pointerId)112 static int32_t getActionUpForPointerId(const NotifyMotionArgs& args, int32_t pointerId) {
113     for (size_t i = 0; i < args.pointerCount; i++) {
114         if (pointerId == args.pointerProperties[i].id) {
115             return AMOTION_EVENT_ACTION_POINTER_UP |
116                     (i << AMOTION_EVENT_ACTION_POINTER_INDEX_SHIFT);
117         }
118     }
119     LOG_ALWAYS_FATAL("Can't find pointerId %" PRId32 " in %s", pointerId, args.dump().c_str());
120 }
121 
122 /**
123  * Find the action for individual pointer at the given pointer index.
124  * This is always equal to MotionEvent::getActionMasked, except for
125  * POINTER_UP or POINTER_DOWN events. For example, in a POINTER_UP event, the action for
126  * the active pointer is ACTION_POINTER_UP, while the action for the other pointers is ACTION_MOVE.
127  */
resolveActionForPointer(uint8_t pointerIndex,int32_t action)128 static int32_t resolveActionForPointer(uint8_t pointerIndex, int32_t action) {
129     const int32_t actionMasked = MotionEvent::getActionMasked(action);
130     if (actionMasked != AMOTION_EVENT_ACTION_POINTER_DOWN &&
131         actionMasked != AMOTION_EVENT_ACTION_POINTER_UP) {
132         return actionMasked;
133     }
134     // This is a POINTER_DOWN or POINTER_UP event
135     const uint8_t actionIndex = MotionEvent::getActionIndex(action);
136     if (pointerIndex == actionIndex) {
137         return actionMasked;
138     }
139     // When POINTER_DOWN or POINTER_UP happens, it's actually a MOVE for all of the other
140     // pointers
141     return AMOTION_EVENT_ACTION_MOVE;
142 }
143 
/**
 * Return a copy of |args| with every pointer whose id is in |pointerIds| dropped.
 *
 * Surviving pointers keep their relative order and are compacted to the front of
 * pointerProperties / pointerCoords. For ACTION_POINTER_DOWN / ACTION_POINTER_UP:
 *  - if the pointer addressed by the action index is itself dropped, the action becomes
 *    ACTION_UNKNOWN and the caller must replace it with something meaningful;
 *  - otherwise the action index is rewritten to that pointer's new position, and when exactly
 *    one pointer survives the action is converted to plain ACTION_DOWN / ACTION_UP.
 * Every other action is copied unchanged, even if no pointers remain.
 */
NotifyMotionArgs removePointerIds(const NotifyMotionArgs& args,
                                  const std::set<int32_t>& pointerIds) {
    const int32_t maskedAction = MotionEvent::getActionMasked(args.action);
    const uint8_t activeIndex = MotionEvent::getActionIndex(args.action);
    const bool rewritesIndex = maskedAction == AMOTION_EVENT_ACTION_POINTER_DOWN ||
            maskedAction == AMOTION_EVENT_ACTION_POINTER_UP;

    NotifyMotionArgs result{args};
    result.pointerCount = 0;
    int32_t activeIndexAfterRemoval = 0;
    for (uint32_t src = 0; src < args.pointerCount; src++) {
        const bool dropPointer = pointerIds.count(args.pointerProperties[src].id) != 0;
        if (dropPointer) {
            if (rewritesIndex && src == activeIndex) {
                // The pointer the action refers to is gone, so the action no longer makes sense.
                // The caller is expected to overwrite ACTION_UNKNOWN later.
                result.action = ACTION_UNKNOWN;
            }
            continue;
        }
        const uint32_t dst = result.pointerCount;
        result.pointerProperties[dst].copyFrom(args.pointerProperties[src]);
        result.pointerCoords[dst].copyFrom(args.pointerCoords[src]);
        if (src == activeIndex) {
            activeIndexAfterRemoval = dst;
        }
        result.pointerCount++;
    }

    if (!rewritesIndex || result.action == ACTION_UNKNOWN) {
        return result;
    }
    if (result.pointerCount == 1) {
        // A lone remaining pointer starts or ends the gesture rather than joining/leaving it.
        result.action = maskedAction == AMOTION_EVENT_ACTION_POINTER_DOWN
                ? AMOTION_EVENT_ACTION_DOWN
                : AMOTION_EVENT_ACTION_UP;
    } else {
        result.action = maskedAction |
                (activeIndexAfterRemoval << AMOTION_EVENT_ACTION_POINTER_INDEX_SHIFT);
    }
    return result;
}
188 
189 /**
190  * Remove stylus pointers from the provided NotifyMotionArgs.
191  *
192  * Return NotifyMotionArgs where the stylus pointers have been removed.
193  * If this results in removal of the active pointer, then return nullopt.
194  */
removeStylusPointerIds(const NotifyMotionArgs & args)195 static std::optional<NotifyMotionArgs> removeStylusPointerIds(const NotifyMotionArgs& args) {
196     std::set<int32_t> stylusPointerIds;
197     for (uint32_t i = 0; i < args.pointerCount; i++) {
198         if (args.pointerProperties[i].toolType == AMOTION_EVENT_TOOL_TYPE_STYLUS) {
199             stylusPointerIds.insert(args.pointerProperties[i].id);
200         }
201     }
202     NotifyMotionArgs withoutStylusPointers = removePointerIds(args, stylusPointerIds);
203     if (withoutStylusPointers.pointerCount == 0 || withoutStylusPointers.action == ACTION_UNKNOWN) {
204         return std::nullopt;
205     }
206     return withoutStylusPointers;
207 }
208 
createPalmFilterDeviceInfo(const InputDeviceInfo & deviceInfo)209 std::optional<AndroidPalmFilterDeviceInfo> createPalmFilterDeviceInfo(
210         const InputDeviceInfo& deviceInfo) {
211     if (!isFromTouchscreen(deviceInfo.getSources())) {
212         return std::nullopt;
213     }
214     AndroidPalmFilterDeviceInfo out;
215     const InputDeviceInfo::MotionRange* axisX =
216             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_X, AINPUT_SOURCE_TOUCHSCREEN);
217     if (axisX != nullptr) {
218         out.max_x = axisX->max;
219         out.x_res = axisX->resolution;
220     } else {
221         ALOGW("Palm rejection is disabled for %s because AXIS_X is not supported",
222               deviceInfo.getDisplayName().c_str());
223         return std::nullopt;
224     }
225     const InputDeviceInfo::MotionRange* axisY =
226             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_Y, AINPUT_SOURCE_TOUCHSCREEN);
227     if (axisY != nullptr) {
228         out.max_y = axisY->max;
229         out.y_res = axisY->resolution;
230     } else {
231         ALOGW("Palm rejection is disabled for %s because AXIS_Y is not supported",
232               deviceInfo.getDisplayName().c_str());
233         return std::nullopt;
234     }
235     const InputDeviceInfo::MotionRange* axisMajor =
236             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_TOUCH_MAJOR, AINPUT_SOURCE_TOUCHSCREEN);
237     if (axisMajor != nullptr) {
238         out.major_radius_res = axisMajor->resolution;
239         out.touch_major_res = axisMajor->resolution;
240     } else {
241         return std::nullopt;
242     }
243     const InputDeviceInfo::MotionRange* axisMinor =
244             deviceInfo.getMotionRange(AMOTION_EVENT_AXIS_TOUCH_MINOR, AINPUT_SOURCE_TOUCHSCREEN);
245     if (axisMinor != nullptr) {
246         out.minor_radius_res = axisMinor->resolution;
247         out.touch_minor_res = axisMinor->resolution;
248         out.minor_radius_supported = true;
249     } else {
250         out.minor_radius_supported = false;
251     }
252 
253     return out;
254 }
255 
256 /**
257  * Synthesize CANCEL events for any new pointers that should be canceled, while removing pointers
258  * that have already been canceled.
259  * The flow of the function is as follows:
260  * 1. Remove all already canceled pointers
261  * 2. Cancel all newly suppressed pointers
262  * 3. Decide what to do with the current event : keep it, or drop it
263  * The pointers can never be "unsuppressed": once a pointer is canceled, it will never become valid.
264  */
cancelSuppressedPointers(const NotifyMotionArgs & args,const std::set<int32_t> & oldSuppressedPointerIds,const std::set<int32_t> & newSuppressedPointerIds)265 std::vector<NotifyMotionArgs> cancelSuppressedPointers(
266         const NotifyMotionArgs& args, const std::set<int32_t>& oldSuppressedPointerIds,
267         const std::set<int32_t>& newSuppressedPointerIds) {
268     LOG_ALWAYS_FATAL_IF(args.pointerCount == 0, "0 pointers in %s", args.dump().c_str());
269 
270     // First, let's remove the old suppressed pointers. They've already been canceled previously.
271     NotifyMotionArgs oldArgs = removePointerIds(args, oldSuppressedPointerIds);
272 
273     // Cancel any newly suppressed pointers.
274     std::vector<NotifyMotionArgs> out;
275     const int32_t activePointerId =
276             args.pointerProperties[MotionEvent::getActionIndex(args.action)].id;
277     const int32_t actionMasked = MotionEvent::getActionMasked(args.action);
278     // We will iteratively remove pointers from 'removedArgs'.
279     NotifyMotionArgs removedArgs{oldArgs};
280     for (uint32_t i = 0; i < oldArgs.pointerCount; i++) {
281         const int32_t pointerId = oldArgs.pointerProperties[i].id;
282         if (newSuppressedPointerIds.find(pointerId) == newSuppressedPointerIds.end()) {
283             // This is a pointer that should not be canceled. Move on.
284             continue;
285         }
286         if (pointerId == activePointerId && actionMasked == AMOTION_EVENT_ACTION_POINTER_DOWN) {
287             // Remove this pointer, but don't cancel it. We'll just not send the POINTER_DOWN event
288             removedArgs = removePointerIds(removedArgs, {pointerId});
289             continue;
290         }
291 
292         if (removedArgs.pointerCount == 1) {
293             // We are about to remove the last pointer, which means there will be no more gesture
294             // remaining. This is identical to canceling all pointers, so just send a single CANCEL
295             // event, without any of the preceding POINTER_UP with FLAG_CANCELED events.
296             oldArgs.flags |= AMOTION_EVENT_FLAG_CANCELED;
297             oldArgs.action = AMOTION_EVENT_ACTION_CANCEL;
298             return {oldArgs};
299         }
300         // Cancel the current pointer
301         out.push_back(removedArgs);
302         out.back().flags |= AMOTION_EVENT_FLAG_CANCELED;
303         out.back().action = getActionUpForPointerId(out.back(), pointerId);
304 
305         // Remove the newly canceled pointer from the args
306         removedArgs = removePointerIds(removedArgs, {pointerId});
307     }
308 
309     // Now 'removedArgs' contains only pointers that are valid.
310     if (removedArgs.pointerCount <= 0 || removedArgs.action == ACTION_UNKNOWN) {
311         return out;
312     }
313     out.push_back(removedArgs);
314     return out;
315 }
316 
UnwantedInteractionBlocker(InputListenerInterface & listener)317 UnwantedInteractionBlocker::UnwantedInteractionBlocker(InputListenerInterface& listener)
318       : UnwantedInteractionBlocker(listener, isPalmRejectionEnabled()){};
319 
UnwantedInteractionBlocker(InputListenerInterface & listener,bool enablePalmRejection)320 UnwantedInteractionBlocker::UnwantedInteractionBlocker(InputListenerInterface& listener,
321                                                        bool enablePalmRejection)
322       : mQueuedListener(listener), mEnablePalmRejection(enablePalmRejection) {}
323 
notifyConfigurationChanged(const NotifyConfigurationChangedArgs * args)324 void UnwantedInteractionBlocker::notifyConfigurationChanged(
325         const NotifyConfigurationChangedArgs* args) {
326     mQueuedListener.notifyConfigurationChanged(args);
327     mQueuedListener.flush();
328 }
329 
notifyKey(const NotifyKeyArgs * args)330 void UnwantedInteractionBlocker::notifyKey(const NotifyKeyArgs* args) {
331     mQueuedListener.notifyKey(args);
332     mQueuedListener.flush();
333 }
334 
notifyMotion(const NotifyMotionArgs * args)335 void UnwantedInteractionBlocker::notifyMotion(const NotifyMotionArgs* args) {
336     ALOGD_IF(DEBUG_INBOUND_MOTION, "%s: %s", __func__, args->dump().c_str());
337     { // acquire lock
338         std::scoped_lock lock(mLock);
339         const std::vector<NotifyMotionArgs> processedArgs =
340                 mPreferStylusOverTouchBlocker.processMotion(*args);
341         for (const NotifyMotionArgs& loopArgs : processedArgs) {
342             notifyMotionLocked(&loopArgs);
343         }
344     } // release lock
345 
346     // Call out to the next stage without holding the lock
347     mQueuedListener.flush();
348 }
349 
enqueueOutboundMotionLocked(const NotifyMotionArgs & args)350 void UnwantedInteractionBlocker::enqueueOutboundMotionLocked(const NotifyMotionArgs& args) {
351     ALOGD_IF(DEBUG_OUTBOUND_MOTION, "%s: %s", __func__, args.dump().c_str());
352     mQueuedListener.notifyMotion(&args);
353 }
354 
notifyMotionLocked(const NotifyMotionArgs * args)355 void UnwantedInteractionBlocker::notifyMotionLocked(const NotifyMotionArgs* args) {
356     auto it = mPalmRejectors.find(args->deviceId);
357     const bool sendToPalmRejector = it != mPalmRejectors.end() && isFromTouchscreen(args->source);
358     if (!sendToPalmRejector) {
359         enqueueOutboundMotionLocked(*args);
360         return;
361     }
362 
363     std::vector<NotifyMotionArgs> processedArgs = it->second.processMotion(*args);
364     for (const NotifyMotionArgs& loopArgs : processedArgs) {
365         enqueueOutboundMotionLocked(loopArgs);
366     }
367 }
368 
notifySwitch(const NotifySwitchArgs * args)369 void UnwantedInteractionBlocker::notifySwitch(const NotifySwitchArgs* args) {
370     mQueuedListener.notifySwitch(args);
371     mQueuedListener.flush();
372 }
373 
notifySensor(const NotifySensorArgs * args)374 void UnwantedInteractionBlocker::notifySensor(const NotifySensorArgs* args) {
375     mQueuedListener.notifySensor(args);
376     mQueuedListener.flush();
377 }
378 
notifyVibratorState(const NotifyVibratorStateArgs * args)379 void UnwantedInteractionBlocker::notifyVibratorState(const NotifyVibratorStateArgs* args) {
380     mQueuedListener.notifyVibratorState(args);
381     mQueuedListener.flush();
382 }
notifyDeviceReset(const NotifyDeviceResetArgs * args)383 void UnwantedInteractionBlocker::notifyDeviceReset(const NotifyDeviceResetArgs* args) {
384     { // acquire lock
385         std::scoped_lock lock(mLock);
386         auto it = mPalmRejectors.find(args->deviceId);
387         if (it != mPalmRejectors.end()) {
388             AndroidPalmFilterDeviceInfo info = it->second.getPalmFilterDeviceInfo();
389             // Re-create the object instead of resetting it
390             mPalmRejectors.erase(it);
391             mPalmRejectors.emplace(args->deviceId, info);
392         }
393         mQueuedListener.notifyDeviceReset(args);
394         mPreferStylusOverTouchBlocker.notifyDeviceReset(*args);
395     } // release lock
396     // Send events to the next stage without holding the lock
397     mQueuedListener.flush();
398 }
399 
notifyPointerCaptureChanged(const NotifyPointerCaptureChangedArgs * args)400 void UnwantedInteractionBlocker::notifyPointerCaptureChanged(
401         const NotifyPointerCaptureChangedArgs* args) {
402     mQueuedListener.notifyPointerCaptureChanged(args);
403     mQueuedListener.flush();
404 }
405 
notifyInputDevicesChanged(const std::vector<InputDeviceInfo> & inputDevices)406 void UnwantedInteractionBlocker::notifyInputDevicesChanged(
407         const std::vector<InputDeviceInfo>& inputDevices) {
408     std::scoped_lock lock(mLock);
409     if (!mEnablePalmRejection) {
410         // Palm rejection is disabled. Don't create any palm rejector objects.
411         return;
412     }
413 
414     // Let's see which of the existing devices didn't change, so that we can keep them
415     // and prevent event stream disruption
416     std::set<int32_t /*deviceId*/> devicesToKeep;
417     for (const InputDeviceInfo& device : inputDevices) {
418         std::optional<AndroidPalmFilterDeviceInfo> info = createPalmFilterDeviceInfo(device);
419         if (!info) {
420             continue;
421         }
422 
423         auto [it, emplaced] = mPalmRejectors.try_emplace(device.getId(), *info);
424         if (!emplaced && *info != it->second.getPalmFilterDeviceInfo()) {
425             // Re-create the PalmRejector because the device info has changed.
426             mPalmRejectors.erase(it);
427             mPalmRejectors.emplace(device.getId(), *info);
428         }
429         devicesToKeep.insert(device.getId());
430     }
431     // Delete all devices that we don't need to keep
432     std::erase_if(mPalmRejectors, [&devicesToKeep](const auto& item) {
433         auto const& [deviceId, _] = item;
434         return devicesToKeep.find(deviceId) == devicesToKeep.end();
435     });
436     mPreferStylusOverTouchBlocker.notifyInputDevicesChanged(inputDevices);
437 }
438 
dump(std::string & dump)439 void UnwantedInteractionBlocker::dump(std::string& dump) {
440     std::scoped_lock lock(mLock);
441     dump += "UnwantedInteractionBlocker:\n";
442     dump += "  mPreferStylusOverTouchBlocker:\n";
443     dump += addLinePrefix(mPreferStylusOverTouchBlocker.dump(), "    ");
444     dump += StringPrintf("  mEnablePalmRejection: %s\n",
445                          std::to_string(mEnablePalmRejection).c_str());
446     dump += StringPrintf("  isPalmRejectionEnabled (flag value): %s\n",
447                          std::to_string(isPalmRejectionEnabled()).c_str());
448     dump += mPalmRejectors.empty() ? "  mPalmRejectors: None\n" : "  mPalmRejectors:\n";
449     for (const auto& [deviceId, palmRejector] : mPalmRejectors) {
450         dump += StringPrintf("    deviceId = %" PRId32 ":\n", deviceId);
451         dump += addLinePrefix(palmRejector.dump(), "      ");
452     }
453 }
454 
monitor()455 void UnwantedInteractionBlocker::monitor() {
456     std::scoped_lock lock(mLock);
457 }
458 
~UnwantedInteractionBlocker()459 UnwantedInteractionBlocker::~UnwantedInteractionBlocker() {}
460 
update(const NotifyMotionArgs & args)461 void SlotState::update(const NotifyMotionArgs& args) {
462     for (size_t i = 0; i < args.pointerCount; i++) {
463         const int32_t pointerId = args.pointerProperties[i].id;
464         const int32_t resolvedAction = resolveActionForPointer(i, args.action);
465         processPointerId(pointerId, resolvedAction);
466     }
467 }
468 
findUnusedSlot() const469 size_t SlotState::findUnusedSlot() const {
470     size_t unusedSlot = 0;
471     // Since the collection is ordered, we can rely on the in-order traversal
472     for (const auto& [slot, trackingId] : mPointerIdsBySlot) {
473         if (unusedSlot != slot) {
474             break;
475         }
476         unusedSlot++;
477     }
478     return unusedSlot;
479 }
480 
processPointerId(int pointerId,int32_t actionMasked)481 void SlotState::processPointerId(int pointerId, int32_t actionMasked) {
482     switch (MotionEvent::getActionMasked(actionMasked)) {
483         case AMOTION_EVENT_ACTION_DOWN:
484         case AMOTION_EVENT_ACTION_POINTER_DOWN:
485         case AMOTION_EVENT_ACTION_HOVER_ENTER: {
486             // New pointer going down
487             size_t newSlot = findUnusedSlot();
488             mPointerIdsBySlot[newSlot] = pointerId;
489             mSlotsByPointerId[pointerId] = newSlot;
490             return;
491         }
492         case AMOTION_EVENT_ACTION_MOVE:
493         case AMOTION_EVENT_ACTION_HOVER_MOVE: {
494             return;
495         }
496         case AMOTION_EVENT_ACTION_CANCEL:
497         case AMOTION_EVENT_ACTION_POINTER_UP:
498         case AMOTION_EVENT_ACTION_UP:
499         case AMOTION_EVENT_ACTION_HOVER_EXIT: {
500             auto it = mSlotsByPointerId.find(pointerId);
501             LOG_ALWAYS_FATAL_IF(it == mSlotsByPointerId.end());
502             size_t slot = it->second;
503             // Erase this pointer from both collections
504             mPointerIdsBySlot.erase(slot);
505             mSlotsByPointerId.erase(pointerId);
506             return;
507         }
508     }
509     LOG_ALWAYS_FATAL("Unhandled action : %s", MotionEvent::actionToString(actionMasked).c_str());
510     return;
511 }
512 
getSlotForPointerId(int32_t pointerId) const513 std::optional<size_t> SlotState::getSlotForPointerId(int32_t pointerId) const {
514     auto it = mSlotsByPointerId.find(pointerId);
515     if (it == mSlotsByPointerId.end()) {
516         return std::nullopt;
517     }
518     return it->second;
519 }
520 
dump() const521 std::string SlotState::dump() const {
522     std::string out = "mSlotsByPointerId:\n";
523     out += addLinePrefix(dumpMap(mSlotsByPointerId), "  ") + "\n";
524     out += "mPointerIdsBySlot:\n";
525     out += addLinePrefix(dumpMap(mPointerIdsBySlot), "  ") + "\n";
526     return out;
527 }
528 
529 class AndroidPalmRejectionModel : public ::ui::OneDeviceTrainNeuralStylusPalmDetectionFilterModel {
530 public:
AndroidPalmRejectionModel()531     AndroidPalmRejectionModel()
532           : ::ui::OneDeviceTrainNeuralStylusPalmDetectionFilterModel(/*default version*/ "",
533                                                                      std::vector<float>()) {
534         config_.resample_period = ::ui::kResamplePeriod;
535     }
536 };
537 
PalmRejector(const AndroidPalmFilterDeviceInfo & info,std::unique_ptr<::ui::PalmDetectionFilter> filter)538 PalmRejector::PalmRejector(const AndroidPalmFilterDeviceInfo& info,
539                            std::unique_ptr<::ui::PalmDetectionFilter> filter)
540       : mSharedPalmState(std::make_unique<::ui::SharedPalmDetectionFilterState>()),
541         mDeviceInfo(info),
542         mPalmDetectionFilter(std::move(filter)) {
543     if (mPalmDetectionFilter != nullptr) {
544         // This path is used for testing. Non-testing invocations should let this constructor
545         // create a real PalmDetectionFilter
546         return;
547     }
548     std::unique_ptr<::ui::NeuralStylusPalmDetectionFilterModel> model =
549             std::make_unique<AndroidPalmRejectionModel>();
550     mPalmDetectionFilter = std::make_unique<PalmFilterImplementation>(mDeviceInfo, std::move(model),
551                                                                       mSharedPalmState.get());
552 }
553 
getTouches(const NotifyMotionArgs & args,const AndroidPalmFilterDeviceInfo & deviceInfo,const SlotState & oldSlotState,const SlotState & newSlotState)554 std::vector<::ui::InProgressTouchEvdev> getTouches(const NotifyMotionArgs& args,
555                                                    const AndroidPalmFilterDeviceInfo& deviceInfo,
556                                                    const SlotState& oldSlotState,
557                                                    const SlotState& newSlotState) {
558     std::vector<::ui::InProgressTouchEvdev> touches;
559 
560     for (size_t i = 0; i < args.pointerCount; i++) {
561         const int32_t pointerId = args.pointerProperties[i].id;
562         touches.emplace_back(::ui::InProgressTouchEvdev());
563         touches.back().major = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_TOUCH_MAJOR);
564         touches.back().minor = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_TOUCH_MINOR);
565         // The field 'tool_type' is not used for palm rejection
566 
567         // Whether there is new information for the touch.
568         touches.back().altered = true;
569 
570         // Whether the touch was cancelled. Touch events should be ignored till a
571         // new touch is initiated.
572         touches.back().was_cancelled = false;
573 
574         // Whether the touch is going to be canceled.
575         touches.back().cancelled = false;
576 
577         // Whether the touch is delayed at first appearance. Will not be reported yet.
578         touches.back().delayed = false;
579 
580         // Whether the touch was delayed before.
581         touches.back().was_delayed = false;
582 
583         // Whether the touch is held until end or no longer held.
584         touches.back().held = false;
585 
586         // Whether this touch was held before being sent.
587         touches.back().was_held = false;
588 
589         const int32_t resolvedAction = resolveActionForPointer(i, args.action);
590         const bool isDown = resolvedAction == AMOTION_EVENT_ACTION_POINTER_DOWN ||
591                 resolvedAction == AMOTION_EVENT_ACTION_DOWN;
592         touches.back().was_touching = !isDown;
593 
594         const bool isUpOrCancel = resolvedAction == AMOTION_EVENT_ACTION_CANCEL ||
595                 resolvedAction == AMOTION_EVENT_ACTION_UP ||
596                 resolvedAction == AMOTION_EVENT_ACTION_POINTER_UP;
597 
598         touches.back().x = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_X);
599         touches.back().y = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_Y);
600 
601         std::optional<size_t> slot = newSlotState.getSlotForPointerId(pointerId);
602         if (!slot) {
603             slot = oldSlotState.getSlotForPointerId(pointerId);
604         }
605         LOG_ALWAYS_FATAL_IF(!slot, "Could not find slot for pointer %d", pointerId);
606         touches.back().slot = *slot;
607         touches.back().tracking_id = (!isUpOrCancel) ? pointerId : -1;
608         touches.back().touching = !isUpOrCancel;
609 
610         // The fields 'radius_x' and 'radius_x' are not used for palm rejection
611         touches.back().pressure = args.pointerCoords[i].getAxisValue(AMOTION_EVENT_AXIS_PRESSURE);
612         touches.back().tool_code = getLinuxToolCode(args.pointerProperties[i].toolType);
613         // The field 'orientation' is not used for palm rejection
614         // The fields 'tilt_x' and 'tilt_y' are not used for palm rejection
615         // The field 'reported_tool_type' is not used for palm rejection
616         touches.back().stylus_button = false;
617     }
618     return touches;
619 }
620 
detectPalmPointers(const NotifyMotionArgs & args)621 std::set<int32_t> PalmRejector::detectPalmPointers(const NotifyMotionArgs& args) {
622     std::bitset<::ui::kNumTouchEvdevSlots> slotsToHold;
623     std::bitset<::ui::kNumTouchEvdevSlots> slotsToSuppress;
624 
625     // Store the slot state before we call getTouches and update it. This way, we can find
626     // the slots that have been removed due to the incoming event.
627     SlotState oldSlotState = mSlotState;
628     mSlotState.update(args);
629 
630     std::vector<::ui::InProgressTouchEvdev> touches =
631             getTouches(args, mDeviceInfo, oldSlotState, mSlotState);
632     ::base::TimeTicks chromeTimestamp = toChromeTimestamp(args.eventTime);
633 
634     if (DEBUG_MODEL) {
635         std::stringstream touchesStream;
636         for (const ::ui::InProgressTouchEvdev& touch : touches) {
637             touchesStream << touch.tracking_id << " : " << touch << "\n";
638         }
639         ALOGD("Filter: touches = %s", touchesStream.str().c_str());
640     }
641 
642     mPalmDetectionFilter->Filter(touches, chromeTimestamp, &slotsToHold, &slotsToSuppress);
643 
644     ALOGD_IF(DEBUG_MODEL, "Response: slotsToHold = %s, slotsToSuppress = %s",
645              slotsToHold.to_string().c_str(), slotsToSuppress.to_string().c_str());
646 
647     // Now that we know which slots should be suppressed, let's convert those to pointer id's.
648     std::set<int32_t> newSuppressedIds;
649     for (size_t i = 0; i < args.pointerCount; i++) {
650         const int32_t pointerId = args.pointerProperties[i].id;
651         std::optional<size_t> slot = oldSlotState.getSlotForPointerId(pointerId);
652         if (!slot) {
653             slot = mSlotState.getSlotForPointerId(pointerId);
654             LOG_ALWAYS_FATAL_IF(!slot, "Could not find slot for pointer id %" PRId32, pointerId);
655         }
656         if (slotsToSuppress.test(*slot)) {
657             newSuppressedIds.insert(pointerId);
658         }
659     }
660     return newSuppressedIds;
661 }
662 
processMotion(const NotifyMotionArgs & args)663 std::vector<NotifyMotionArgs> PalmRejector::processMotion(const NotifyMotionArgs& args) {
664     if (mPalmDetectionFilter == nullptr) {
665         return {args};
666     }
667     const bool skipThisEvent = args.action == AMOTION_EVENT_ACTION_HOVER_ENTER ||
668             args.action == AMOTION_EVENT_ACTION_HOVER_MOVE ||
669             args.action == AMOTION_EVENT_ACTION_HOVER_EXIT ||
670             args.action == AMOTION_EVENT_ACTION_BUTTON_PRESS ||
671             args.action == AMOTION_EVENT_ACTION_BUTTON_RELEASE ||
672             args.action == AMOTION_EVENT_ACTION_SCROLL;
673     if (skipThisEvent) {
674         // Lets not process hover events, button events, or scroll for now.
675         return {args};
676     }
677     if (args.action == AMOTION_EVENT_ACTION_DOWN) {
678         mSuppressedPointerIds.clear();
679     }
680 
681     std::set<int32_t> oldSuppressedIds;
682     std::swap(oldSuppressedIds, mSuppressedPointerIds);
683 
684     std::optional<NotifyMotionArgs> touchOnlyArgs = removeStylusPointerIds(args);
685     if (touchOnlyArgs) {
686         mSuppressedPointerIds = detectPalmPointers(*touchOnlyArgs);
687     } else {
688         // This is a stylus-only event.
689         // We can skip this event and just keep the suppressed pointer ids the same as before.
690         mSuppressedPointerIds = oldSuppressedIds;
691     }
692 
693     std::vector<NotifyMotionArgs> argsWithoutUnwantedPointers =
694             cancelSuppressedPointers(args, oldSuppressedIds, mSuppressedPointerIds);
695     for (const NotifyMotionArgs& checkArgs : argsWithoutUnwantedPointers) {
696         LOG_ALWAYS_FATAL_IF(checkArgs.action == ACTION_UNKNOWN, "%s", checkArgs.dump().c_str());
697     }
698 
699     // Only log if new pointers are getting rejected. That means mSuppressedPointerIds is not a
700     // subset of oldSuppressedIds.
701     if (!std::includes(oldSuppressedIds.begin(), oldSuppressedIds.end(),
702                        mSuppressedPointerIds.begin(), mSuppressedPointerIds.end())) {
703         ALOGI("Palm detected, removing pointer ids %s after %" PRId64 "ms from %s",
704               dumpSet(mSuppressedPointerIds).c_str(), ns2ms(args.eventTime - args.downTime),
705               args.dump().c_str());
706     }
707 
708     return argsWithoutUnwantedPointers;
709 }
710 
getPalmFilterDeviceInfo() const711 const AndroidPalmFilterDeviceInfo& PalmRejector::getPalmFilterDeviceInfo() const {
712     return mDeviceInfo;
713 }
714 
dump() const715 std::string PalmRejector::dump() const {
716     std::string out;
717     out += "mDeviceInfo:\n";
718     std::stringstream deviceInfo;
719     deviceInfo << mDeviceInfo << ", touch_major_res=" << mDeviceInfo.touch_major_res
720                << ", touch_minor_res=" << mDeviceInfo.touch_minor_res << "\n";
721     out += addLinePrefix(deviceInfo.str(), "  ");
722     out += "mSlotState:\n";
723     out += addLinePrefix(mSlotState.dump(), "  ");
724     out += "mSuppressedPointerIds: ";
725     out += dumpSet(mSuppressedPointerIds) + "\n";
726     std::stringstream state;
727     state << *mSharedPalmState;
728     out += "mSharedPalmState: " + state.str() + "\n";
729     std::stringstream filter;
730     filter << static_cast<const PalmFilterImplementation&>(*mPalmDetectionFilter);
731     out += "mPalmDetectionFilter:\n";
732     out += addLinePrefix(filter.str(), "  ") + "\n";
733     return out;
734 }
735 
736 } // namespace android
737