• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 
17 #define LOG_TAG "InputChannelTest"
18 
19 #include "../includes/common.h"
20 
21 #include <android-base/stringprintf.h>
22 #include <input/InputTransport.h>
23 
24 using namespace android;
25 using android::base::StringPrintf;
26 
/**
 * Renders a raw memory range as two-digit uppercase hex bytes, each followed by a single
 * space (e.g. {0x0A, 0xFF} -> "0A FF ").
 *
 * @param address  first byte of the range to dump
 * @param numBytes number of bytes to render
 * @return the rendered string; empty when numBytes is 0
 */
static std::string memoryAsHexString(const void* const address, size_t numBytes) {
    static constexpr char kHexDigits[] = "0123456789ABCDEF";
    const auto* const bytes = static_cast<const uint8_t*>(address);
    const uint8_t* const end = bytes + numBytes;

    std::string hex;
    for (const uint8_t* cursor = bytes; cursor != end; ++cursor) {
        // High nibble, low nibble, separator: same layout as printf("%02X ").
        hex.push_back(kHexDigits[*cursor >> 4]);
        hex.push_back(kHexDigits[*cursor & 0x0F]);
        hex.push_back(' ');
    }
    return hex;
}
34 
35 /**
36  * There could be non-zero bytes in-between InputMessage fields. Force-initialize the entire
37  * memory to zero, then only copy the valid bytes on a per-field basis.
38  * Input: message msg
39  * Output: cleaned message outMsg
40  */
sanitizeMessage(const InputMessage & msg,InputMessage * outMsg)41 static void sanitizeMessage(const InputMessage& msg, InputMessage* outMsg) {
42     memset(outMsg, 0, sizeof(*outMsg));
43 
44     // Write the header
45     outMsg->header.type = msg.header.type;
46     outMsg->header.seq = msg.header.seq;
47 
48     // Write the body
49     switch(msg.header.type) {
50         case InputMessage::Type::KEY: {
51             // int32_t eventId
52             outMsg->body.key.eventId = msg.body.key.eventId;
53             // nsecs_t eventTime
54             outMsg->body.key.eventTime = msg.body.key.eventTime;
55             // int32_t deviceId
56             outMsg->body.key.deviceId = msg.body.key.deviceId;
57             // int32_t source
58             outMsg->body.key.source = msg.body.key.source;
59             // int32_t displayId
60             outMsg->body.key.displayId = msg.body.key.displayId;
61             // std::array<uint8_t, 32> hmac
62             outMsg->body.key.hmac = msg.body.key.hmac;
63             // int32_t action
64             outMsg->body.key.action = msg.body.key.action;
65             // int32_t flags
66             outMsg->body.key.flags = msg.body.key.flags;
67             // int32_t keyCode
68             outMsg->body.key.keyCode = msg.body.key.keyCode;
69             // int32_t scanCode
70             outMsg->body.key.scanCode = msg.body.key.scanCode;
71             // int32_t metaState
72             outMsg->body.key.metaState = msg.body.key.metaState;
73             // int32_t repeatCount
74             outMsg->body.key.repeatCount = msg.body.key.repeatCount;
75             // nsecs_t downTime
76             outMsg->body.key.downTime = msg.body.key.downTime;
77             break;
78         }
79         case InputMessage::Type::MOTION: {
80             // int32_t eventId
81             outMsg->body.motion.eventId = msg.body.key.eventId;
82             // nsecs_t eventTime
83             outMsg->body.motion.eventTime = msg.body.motion.eventTime;
84             // int32_t deviceId
85             outMsg->body.motion.deviceId = msg.body.motion.deviceId;
86             // int32_t source
87             outMsg->body.motion.source = msg.body.motion.source;
88             // int32_t displayId
89             outMsg->body.motion.displayId = msg.body.motion.displayId;
90             // std::array<uint8_t, 32> hmac
91             outMsg->body.motion.hmac = msg.body.motion.hmac;
92             // int32_t action
93             outMsg->body.motion.action = msg.body.motion.action;
94             // int32_t actionButton
95             outMsg->body.motion.actionButton = msg.body.motion.actionButton;
96             // int32_t flags
97             outMsg->body.motion.flags = msg.body.motion.flags;
98             // int32_t metaState
99             outMsg->body.motion.metaState = msg.body.motion.metaState;
100             // int32_t buttonState
101             outMsg->body.motion.buttonState = msg.body.motion.buttonState;
102             // MotionClassification classification
103             outMsg->body.motion.classification = msg.body.motion.classification;
104             // int32_t edgeFlags
105             outMsg->body.motion.edgeFlags = msg.body.motion.edgeFlags;
106             // nsecs_t downTime
107             outMsg->body.motion.downTime = msg.body.motion.downTime;
108             // float dsdx
109             outMsg->body.motion.dsdx = msg.body.motion.dsdx;
110             // float dtdx
111             outMsg->body.motion.dtdx = msg.body.motion.dtdx;
112             // float dtdy
113             outMsg->body.motion.dtdy = msg.body.motion.dtdy;
114             // float dsdy
115             outMsg->body.motion.dsdy = msg.body.motion.dsdy;
116             // float tx
117             outMsg->body.motion.tx = msg.body.motion.tx;
118             // float ty
119             outMsg->body.motion.ty = msg.body.motion.ty;
120             // float xPrecision
121             outMsg->body.motion.xPrecision = msg.body.motion.xPrecision;
122             // float yPrecision
123             outMsg->body.motion.yPrecision = msg.body.motion.yPrecision;
124             // float xCursorPosition
125             outMsg->body.motion.xCursorPosition = msg.body.motion.xCursorPosition;
126             // float yCursorPosition
127             outMsg->body.motion.yCursorPosition = msg.body.motion.yCursorPosition;
128             // int32_t displayW
129             outMsg->body.motion.displayWidth = msg.body.motion.displayWidth;
130             // int32_t displayH
131             outMsg->body.motion.displayHeight = msg.body.motion.displayHeight;
132             // uint32_t pointerCount
133             outMsg->body.motion.pointerCount = msg.body.motion.pointerCount;
134             //struct Pointer pointers[MAX_POINTERS]
135             for (size_t i = 0; i < msg.body.motion.pointerCount; i++) {
136                 // PointerProperties properties
137                 outMsg->body.motion.pointers[i].properties.id =
138                         msg.body.motion.pointers[i].properties.id;
139                 outMsg->body.motion.pointers[i].properties.toolType =
140                         msg.body.motion.pointers[i].properties.toolType;
141                 // PointerCoords coords
142                 outMsg->body.motion.pointers[i].coords.bits =
143                         msg.body.motion.pointers[i].coords.bits;
144                 const uint32_t count = BitSet64::count(msg.body.motion.pointers[i].coords.bits);
145                 memcpy(&outMsg->body.motion.pointers[i].coords.values[0],
146                         &msg.body.motion.pointers[i].coords.values[0],
147                         count * sizeof(msg.body.motion.pointers[i].coords.values[0]));
148             }
149             break;
150         }
151         case InputMessage::Type::FINISHED: {
152             outMsg->body.finished.handled = msg.body.finished.handled;
153             outMsg->body.finished.consumeTime = msg.body.finished.consumeTime;
154             break;
155         }
156         case InputMessage::Type::FOCUS: {
157             outMsg->body.focus.eventId = msg.body.focus.eventId;
158             outMsg->body.focus.hasFocus = msg.body.focus.hasFocus;
159             outMsg->body.focus.inTouchMode = msg.body.focus.inTouchMode;
160             break;
161         }
162         case InputMessage::Type::CAPTURE: {
163             outMsg->body.capture.eventId = msg.body.capture.eventId;
164             outMsg->body.capture.pointerCaptureEnabled = msg.body.capture.pointerCaptureEnabled;
165             break;
166         }
167         case InputMessage::Type::DRAG: {
168             outMsg->body.capture.eventId = msg.body.capture.eventId;
169             outMsg->body.drag.isExiting = msg.body.drag.isExiting;
170             outMsg->body.drag.x = msg.body.drag.x;
171             outMsg->body.drag.y = msg.body.drag.y;
172             break;
173         }
174         case InputMessage::Type::TIMELINE: {
175             outMsg->body.timeline.eventId = msg.body.timeline.eventId;
176             outMsg->body.timeline.graphicsTimeline = msg.body.timeline.graphicsTimeline;
177             break;
178         }
179     }
180 }
181 
makeMessageValid(InputMessage & msg)182 static void makeMessageValid(InputMessage& msg) {
183     InputMessage::Type type = msg.header.type;
184     if (type == InputMessage::Type::MOTION) {
185         // Message is considered invalid if it has more than MAX_POINTERS pointers.
186         msg.body.motion.pointerCount = MAX_POINTERS;
187     }
188     if (type == InputMessage::Type::TIMELINE) {
189         // Message is considered invalid if presentTime <= gpuCompletedTime
190         msg.body.timeline.graphicsTimeline[GraphicsTimeline::GPU_COMPLETED_TIME] = 10;
191         msg.body.timeline.graphicsTimeline[GraphicsTimeline::PRESENT_TIME] = 20;
192     }
193 }
194 
195 /**
196  * Return false if vulnerability is found for a given message type
197  */
checkMessage(InputChannel & server,InputChannel & client,InputMessage::Type type)198 static bool checkMessage(InputChannel& server, InputChannel& client, InputMessage::Type type) {
199     InputMessage serverMsg;
200     // Set all potentially uninitialized bytes to 1, for easier comparison
201 
202     memset(&serverMsg, 1, sizeof(serverMsg));
203     serverMsg.header.type = type;
204     makeMessageValid(serverMsg);
205     status_t result = server.sendMessage(&serverMsg);
206     if (result != OK) {
207         ALOGE("Could not send message to the input channel");
208         return false;
209     }
210 
211     InputMessage clientMsg;
212     result = client.receiveMessage(&clientMsg);
213     if (result != OK) {
214         ALOGE("Could not receive message from the input channel");
215         return false;
216     }
217     if (serverMsg.header.type != clientMsg.header.type) {
218         ALOGE("Types do not match");
219         return false;
220     }
221 
222     InputMessage sanitizedClientMsg;
223     sanitizeMessage(clientMsg, &sanitizedClientMsg);
224     if (memcmp(&clientMsg, &sanitizedClientMsg, clientMsg.size()) != 0) {
225         ALOGE("Client received un-sanitized message");
226         ALOGE("Received message: %s", memoryAsHexString(&clientMsg, clientMsg.size()).c_str());
227         ALOGE("Expected message: %s",
228                 memoryAsHexString(&sanitizedClientMsg, clientMsg.size()).c_str());
229         return false;
230     }
231 
232     return true;
233 }
234 
235 /**
236  * Create an unsanitized message
237  * Send
238  * Receive
239  * Compare the received message to a sanitized expected message
240  * Do this for all message types
241  */
main()242 int main() {
243     std::unique_ptr<InputChannel> server, client;
244 
245     status_t result = InputChannel::openInputChannelPair("channel name", server, client);
246     if (result != OK) {
247         ALOGE("Could not open input channel pair");
248         return 0;
249     }
250 
251     InputMessage::Type types[] = {
252             InputMessage::Type::KEY,      InputMessage::Type::MOTION,  InputMessage::Type::FINISHED,
253             InputMessage::Type::FOCUS,    InputMessage::Type::CAPTURE, InputMessage::Type::DRAG,
254             InputMessage::Type::TIMELINE,
255     };
256     for (InputMessage::Type type : types) {
257         bool success = checkMessage(*server, *client, type);
258         if (!success) {
259             ALOGE("Check message failed for type %i", type);
260             return EXIT_VULNERABLE;
261         }
262     }
263 
264     return 0;
265 }
266