• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# Copyright (C) 2016 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import json
18import logging
19import os
20import socket
21import time
22import types
23
24from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg_pb2
25from vts.proto import ComponentSpecificationMessage_pb2 as CompSpecMsg_pb2
26from vts.runners.host import const
27from vts.runners.host import errors
28from vts.utils.python.mirror import mirror_object
29
30from google.protobuf import text_format
31
# Default target address for VtsTcpClient.Connect(); read once at import time
# from the environment (None when unset; the port is kept as a string).
TARGET_IP = os.environ.get("TARGET_IP", None)
TARGET_PORT = os.environ.get("TARGET_PORT", None)
# Timeout applied to every socket operation once connected.
_DEFAULT_SOCKET_TIMEOUT_SECS = 1800
# Timeout for establishing a single TCP connection attempt.
_SOCKET_CONN_TIMEOUT_SECS = 60
# Default number of connection attempts made by VtsTcpClient.Connect().
_SOCKET_CONN_RETRY_NUMBER = 5
# Human-readable names of command types, used only for logging outgoing
# commands. Keyed by the proto enum constants so the table cannot drift from
# AndroidSystemControlMessage.proto; PING is included because Ping() sends it.
COMMAND_TYPE_NAME = {
    SysMsg_pb2.LIST_HALS: "LIST_HALS",
    SysMsg_pb2.SET_HOST_INFO: "SET_HOST_INFO",
    SysMsg_pb2.PING: "PING",
    SysMsg_pb2.CHECK_DRIVER_SERVICE: "CHECK_DRIVER_SERVICE",
    SysMsg_pb2.LAUNCH_DRIVER_SERVICE: "LAUNCH_DRIVER_SERVICE",
    SysMsg_pb2.VTS_AGENT_COMMAND_READ_SPECIFICATION:
        "VTS_AGENT_COMMAND_READ_SPECIFICATION",
    SysMsg_pb2.LIST_APIS: "LIST_APIS",
    SysMsg_pb2.CALL_API: "CALL_API",
    SysMsg_pb2.VTS_AGENT_COMMAND_GET_ATTRIBUTE:
        "VTS_AGENT_COMMAND_GET_ATTRIBUTE",
    SysMsg_pb2.VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND:
        "VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND",
}
46
47
class VtsTcpClient(object):
    """VTS TCP Client class.

    Talks to the on-device VTS agent over a TCP socket. Each request is a
    serialized AndroidSystemControlCommandMessage preceded by a line holding
    its decimal byte length; each reply is read back with the same framing
    and parsed as an AndroidSystemControlResponseMessage.

    Attribute:
        connection: a TCP socket instance.
        channel: a file-like object wrapping `connection`, used to write and
                 read the framed messages.
        _mode: the connection mode (adb_forwarding or ssh_tunnel). It is
               stored but not otherwise read by this class.
    """

    def __init__(self, mode="adb_forwarding"):
        self.connection = None
        self.channel = None
        self._mode = mode

    @staticmethod
    def _IsSuccess(resp):
        """Returns True if resp is a response message whose code is SUCCESS.

        RecvResponse() returns None when every read attempt timed out, so a
        None response counts as a failure instead of being dereferenced.
        """
        return resp is not None and resp.response_code == SysMsg_pb2.SUCCESS

    @staticmethod
    def _ParseTextProto(text, message, error_msg):
        """Merges a text-format proto string sent by the target into message.

        Args:
            text: string, the `result` field of a response message.
            message: protobuf message instance to merge into.
            error_msg: string, message of the exception raised when the
                       target answered with "error" instead of a proto.

        Returns:
            message. On a parse error the error is logged and message is
            returned as-is (possibly empty or partially populated).

        Raises:
            errors.VtsTcpCommunicationError if text is "error".
        """
        if text == "error":
            raise errors.VtsTcpCommunicationError(error_msg)
        try:
            text_format.Merge(text, message)
        except text_format.ParseError as e:
            logging.exception(e)
            logging.error("Parsing error\n%s", text)
        return message

    @staticmethod
    def _RaiseRpcError(arg, resp_code):
        """Logs a likely target crash and raises VtsTcpCommunicationError.

        Args:
            arg: the RPC argument, included in the error message.
            resp_code: the received response code, or None if no response
                       arrived.
        """
        logging.error("NOTICE - Likely a crash discovery!")
        logging.error("SysMsg_pb2.SUCCESS is %s", SysMsg_pb2.SUCCESS)
        raise errors.VtsTcpCommunicationError(
            "RPC Error, response code for %s is %s" % (arg, resp_code))

    def Connect(self, ip=TARGET_IP, command_port=TARGET_PORT,
                callback_port=None, retry=_SOCKET_CONN_RETRY_NUMBER):
        """Connects to a target device.

        Args:
            ip: string, the IP address of a target device.
            command_port: int, the TCP port which can be used to connect to
                          a target device.
            callback_port: int, the TCP port number of a host-side callback
                           server. When given, it is sent to the target with
                           a SET_HOST_INFO command.
            retry: int, the maximum number of connection attempts before
                   giving up. Values below 1 mean a single attempt.

        Returns:
            True if success, False otherwise (command_port is empty, or the
            SET_HOST_INFO command did not get a SUCCESS response).

        Raises:
            errors.VtsTcpClientCreationError if every connection attempt
            fails.
        """
        if not command_port:
            logging.error("ip %s, command_port %s, callback_port %s invalid",
                          ip, command_port, callback_port)
            return False

        attempts = max(1, retry)
        for attempt in range(attempts):
            try:
                connection = socket.create_connection(
                    (ip, command_port), _SOCKET_CONN_TIMEOUT_SECS)
                connection.settimeout(_DEFAULT_SOCKET_TIMEOUT_SECS)
            except socket.error as e:
                logging.exception("Connect failed %s", e)
                if attempt + 1 == attempts:
                    raise errors.VtsTcpClientCreationError(
                        "Couldn't connect to %s:%s" % (ip, command_port))
                # Wait a bit and retry.
                time.sleep(1)
            else:
                # Stop at the first success; continuing the loop would open
                # (and leak) additional sockets.
                self.connection = connection
                break
        self.channel = self.connection.makefile(mode="brw")

        if callback_port is not None:
            self.SendCommand(SysMsg_pb2.SET_HOST_INFO,
                             callback_port=callback_port)
            if not self._IsSuccess(self.RecvResponse()):
                return False
        return True

    def Disconnect(self):
        """Disconnects from the target device.

        Closes the channel (flushing any buffered output) and then the
        socket. A failure while closing the channel is logged and does not
        prevent the socket from being closed. Does nothing when not
        connected.

        TODO(yim): Send a msg to the target side to teardown handler session
        and release memory before closing the socket.
        """
        if self.connection is None:
            return
        channel, self.channel = self.channel, None
        connection, self.connection = self.connection, None
        try:
            if channel is not None:
                channel.close()
        except (IOError, OSError) as e:
            logging.warning("Failed to close channel: %s", e)
        finally:
            connection.close()

    def ListHals(self, base_paths):
        """RPC to LIST_HALS.

        Returns:
            The response's file_names on SUCCESS, None otherwise.
        """
        self.SendCommand(SysMsg_pb2.LIST_HALS, paths=base_paths)
        resp = self.RecvResponse()
        if self._IsSuccess(resp):
            return resp.file_names
        return None

    def CheckDriverService(self, service_name):
        """RPC to CHECK_DRIVER_SERVICE.

        Returns:
            True if the target answered SUCCESS, False otherwise.
        """
        self.SendCommand(SysMsg_pb2.CHECK_DRIVER_SERVICE,
                         service_name=service_name)
        return self._IsSuccess(self.RecvResponse())

    def LaunchDriverService(self, driver_type, service_name, bits,
                            file_path=None, target_class=None, target_type=None,
                            target_version=None, target_package=None,
                            target_component_name=None,
                            hw_binder_service_name=None):
        """RPC to LAUNCH_DRIVER_SERVICE.

        All arguments are forwarded to SendCommand.

        Returns:
            True if the target answered SUCCESS, False otherwise.
        """
        logging.info("service_name: %s", service_name)
        logging.info("file_path: %s", file_path)
        logging.info("bits: %s", bits)
        logging.info("driver_type: %s", driver_type)
        self.SendCommand(SysMsg_pb2.LAUNCH_DRIVER_SERVICE,
                         driver_type=driver_type,
                         service_name=service_name,
                         bits=bits,
                         file_path=file_path,
                         target_class=target_class,
                         target_type=target_type,
                         target_version=target_version,
                         target_package=target_package,
                         target_component_name=target_component_name,
                         hw_binder_service_name=hw_binder_service_name)
        resp = self.RecvResponse()
        logging.info("resp for LAUNCH_DRIVER_SERVICE: %s", resp)
        return self._IsSuccess(resp)

    def ListApis(self):
        """RPC to LIST_APIS.

        Returns:
            The response's spec on SUCCESS, None otherwise.
        """
        self.SendCommand(SysMsg_pb2.LIST_APIS)
        resp = self.RecvResponse()
        logging.info("resp for LIST_APIS: %s", resp)
        if self._IsSuccess(resp):
            return resp.spec
        return None

    def _MembersToDict(self, member_msgs):
        """Converts struct/union member messages into a dict.

        Named members are keyed by their name; unnamed members are keyed
        "attribute<N>", where N is the member's 1-based position among all
        members (named or not).
        """
        result = {}
        for index, member_msg in enumerate(member_msgs, 1):
            key = member_msg.name or "attribute%d" % index
            result[key] = self.GetPythonDataOfVariableSpecMsg(member_msg)
        return result

    def GetPythonDataOfVariableSpecMsg(self, var_spec_msg):
        """Returns the python native data structure for a given message.

        Args:
            var_spec_msg: VariableSpecificationMessage

        Returns:
            python native data structure: a scalar/string for scalar, enum
            and string types, a dict for struct and union types, and a list
            for vector and array types.

        Raises:
            VtsUnsupportedTypeError if unsupported type is specified (this
                includes a TYPE_SCALAR message without a scalar_type).
            VtsMalformedProtoStringError if StringDataValueMessage is
                not populated.
        """
        var_type = var_spec_msg.type
        if var_type == CompSpecMsg_pb2.TYPE_SCALAR:
            scalar_type = getattr(var_spec_msg, "scalar_type", "")
            if scalar_type:
                return getattr(var_spec_msg.scalar_value, scalar_type)
        elif var_type == CompSpecMsg_pb2.TYPE_ENUM:
            scalar_type = getattr(var_spec_msg, "scalar_type", "")
            if scalar_type:
                return getattr(var_spec_msg.scalar_value, scalar_type)
            return var_spec_msg.scalar_value.int32_t
        elif var_type == CompSpecMsg_pb2.TYPE_STRING:
            if hasattr(var_spec_msg, "string_value"):
                return getattr(var_spec_msg.string_value, "message", "")
            raise errors.VtsMalformedProtoStringError()
        elif var_type == CompSpecMsg_pb2.TYPE_STRUCT:
            return self._MembersToDict(var_spec_msg.struct_value)
        elif var_type == CompSpecMsg_pb2.TYPE_UNION:
            # Represented like a struct; the previous code referenced an
            # undefined VtsReturnValueObject class here.
            return self._MembersToDict(var_spec_msg.union_value)
        elif var_type in (CompSpecMsg_pb2.TYPE_VECTOR,
                          CompSpecMsg_pb2.TYPE_ARRAY):
            return [self.GetPythonDataOfVariableSpecMsg(element)
                    for element in var_spec_msg.vector_value]

        raise errors.VtsUnsupportedTypeError(
            "unsupported type %s" % var_spec_msg.type)

    def CallApi(self, arg, caller_uid=None):
        """RPC to CALL_API.

        Args:
            arg: the API call request, sent as the command's arg field.
            caller_uid: optional driver caller uid.

        Returns:
            A MirrorObject for submodule results. Otherwise the converted
            HIDL return value(s) (a single value, or a list when there are
            several) or the raw FunctionSpecificationMessage for non-HIDL
            results; paired with {"coverage": raw_coverage_data} when the
            result message has that attribute.

        Raises:
            errors.VtsTcpCommunicationError if the driver reports an error,
            or the response is missing or not SUCCESS.
        """
        self.SendCommand(SysMsg_pb2.CALL_API, arg=arg, caller_uid=caller_uid)
        resp = self.RecvResponse()
        resp_code = resp.response_code if resp is not None else None
        if resp_code != SysMsg_pb2.SUCCESS:
            self._RaiseRpcError(arg, resp_code)

        result = self._ParseTextProto(
            resp.result, CompSpecMsg_pb2.FunctionSpecificationMessage(),
            "API call error by the VTS driver.")
        if result.return_type.type == CompSpecMsg_pb2.TYPE_SUBMODULE:
            logging.info("returned a submodule spec")
            logging.info("spec: %s", result.return_type_submodule_spec)
            return mirror_object.MirrorObject(
                self, result.return_type_submodule_spec, None)

        logging.info("result: %s", result.return_type_hidl)
        hidl_values = result.return_type_hidl
        if len(hidl_values) == 1:
            result_value = self.GetPythonDataOfVariableSpecMsg(hidl_values[0])
        elif len(hidl_values) > 1:
            result_value = [self.GetPythonDataOfVariableSpecMsg(value)
                            for value in hidl_values]
        elif hasattr(result, "return_type"):  # For non-HIDL return value
            result_value = result
        else:
            result_value = None

        if hasattr(result, "raw_coverage_data"):
            return result_value, {"coverage": result.raw_coverage_data}
        return result_value

    def GetAttribute(self, arg):
        """RPC to VTS_AGENT_COMMAND_GET_ATTRIBUTE.

        Returns:
            A MirrorObject for submodule results, the scalar value for scalar
            results, and the parsed FunctionSpecificationMessage otherwise.

        Raises:
            errors.VtsTcpCommunicationError if the target reports an error,
            or the response is missing or not SUCCESS.
        """
        self.SendCommand(SysMsg_pb2.VTS_AGENT_COMMAND_GET_ATTRIBUTE, arg=arg)
        resp = self.RecvResponse()
        resp_code = resp.response_code if resp is not None else None
        if resp_code != SysMsg_pb2.SUCCESS:
            self._RaiseRpcError(arg, resp_code)

        result = self._ParseTextProto(
            resp.result, CompSpecMsg_pb2.FunctionSpecificationMessage(),
            "Get attribute request failed on target.")
        return_type = result.return_type
        if return_type.type == CompSpecMsg_pb2.TYPE_SUBMODULE:
            logging.info("returned a submodule spec")
            logging.info("spec: %s", result.return_type_submodule_spec)
            return mirror_object.MirrorObject(
                self, result.return_type_submodule_spec, None)
        if return_type.type == CompSpecMsg_pb2.TYPE_SCALAR:
            return getattr(return_type.scalar_value, return_type.scalar_type)
        return result

    def ExecuteShellCommand(self, command):
        """RPC to VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND.

        Args:
            command: a shell command string, or a list/tuple of them.

        Returns:
            dict with const.STDOUT, const.STDERR and const.EXIT_CODE keys;
            all values are None when no successful response was received.
        """
        self.SendCommand(
            SysMsg_pb2.VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND,
            shell_command=command)
        resp = self.RecvResponse(retries=2)
        logging.info("resp for VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND: %s",
                     resp)

        stdout = None
        stderr = None
        exit_code = None

        if not resp:
            logging.error("resp is: %s.", resp)
        elif resp.response_code != SysMsg_pb2.SUCCESS:
            logging.error("resp response code is not success: %s.",
                          resp.response_code)
        else:
            stdout = resp.stdout
            stderr = resp.stderr
            exit_code = resp.exit_code

        return {const.STDOUT: stdout,
                const.STDERR: stderr,
                const.EXIT_CODE: exit_code,
                }

    def Ping(self):
        """RPC to send a PING request.

        Returns:
            True if the agent is alive, False otherwise.
        """
        self.SendCommand(SysMsg_pb2.PING)
        resp = self.RecvResponse()
        logging.info("resp for PING: %s", resp)
        return self._IsSuccess(resp)

    def ReadSpecification(self, interface_name, target_class, target_type,
                          target_version, target_package, recursive=False):
        """RPC to VTS_AGENT_COMMAND_READ_SPECIFICATION.

        Args:
            interface_name: string, sent as the command's service_name.
            other args: see SendCommand
            recursive: boolean, set to also read the specification(s) listed
                       in the `import` field of the returned one and merge
                       them into it. Only direct imports are followed.

        Returns:
            A ComponentSpecificationMessage.

        Raises:
            errors.VtsTcpCommunicationError if no response arrives or the
            target answers with "error".
        """
        self.SendCommand(
            SysMsg_pb2.VTS_AGENT_COMMAND_READ_SPECIFICATION,
            service_name=interface_name,
            target_class=target_class,
            target_type=target_type,
            target_version=target_version,
            target_package=target_package)
        resp = self.RecvResponse(retries=2)
        logging.info("resp for VTS_AGENT_COMMAND_READ_SPECIFICATION: %s",
                     resp)
        if resp is None:
            raise errors.VtsTcpCommunicationError(
                "No response for VTS_AGENT_COMMAND_READ_SPECIFICATION.")
        logging.info("proto: %s", resp.result)
        result = self._ParseTextProto(
            resp.result, CompSpecMsg_pb2.ComponentSpecificationMessage(),
            "API call error by the VTS driver.")

        if recursive and hasattr(result, "import"):
            # `import` is a Python keyword, hence getattr. Iterate over a
            # snapshot: MergeFrom below appends to this same repeated field.
            imported_interfaces = list(getattr(result, "import"))
            # Fallbacks come from this spec, read before any merge can
            # overwrite them.
            component_class = result.component_class
            component_type = result.component_type
            for imported_interface in imported_interfaces:
                # The parsing below assumes "<package>@<version>::<name>".
                imported_result = self.ReadSpecification(
                    imported_interface.split("::")[1],
                    # TODO(yim): derive target_class and
                    # target_type from package path or remove them
                    component_class if target_class is None else target_class,
                    component_type if target_type is None else target_type,
                    float(imported_interface.split("@")[1].split("::")[0]),
                    imported_interface.split("@")[0])
                # NOTE: MergeFrom overwrites singular fields set in
                # imported_result and appends its repeated fields.
                result.MergeFrom(imported_result)

        return result

    def SendCommand(self,
                    command_type,
                    paths=None,
                    file_path=None,
                    bits=None,
                    target_class=None,
                    target_type=None,
                    target_version=None,
                    target_package=None,
                    target_component_name=None,
                    hw_binder_service_name=None,
                    module_name=None,
                    service_name=None,
                    callback_port=None,
                    driver_type=None,
                    shell_command=None,
                    caller_uid=None,
                    arg=None):
        """Sends a command.

        Args:
            command_type: integer, the command type.
            each of the other args are to fill in a field in
            AndroidSystemControlCommandMessage; None means "leave unset".
            target_version is a float and is sent multiplied by 100
            (e.g. 1.0 -> 100). shell_command may be a string or a list/tuple
            of strings. caller_uid fills driver_caller_uid.

        Raises:
            errors.VtsTcpCommunicationError if not connected.
        """
        if not self.channel:
            raise errors.VtsTcpCommunicationError(
                "channel is None, unable to send command.")

        command_msg = SysMsg_pb2.AndroidSystemControlCommandMessage()
        command_msg.command_type = command_type
        # .get: not every command type (e.g. PING) may be in the name table.
        logging.info("sending a command (type %s)",
                     COMMAND_TYPE_NAME.get(command_type, command_type))
        if command_type == SysMsg_pb2.CALL_API:
            logging.info("target API: %s", arg)

        # Optional singular fields copied verbatim when provided.
        optional_fields = (
            ("target_class", target_class),
            ("target_type", target_type),
            ("target_package", target_package),
            ("target_component_name", target_component_name),
            ("hw_binder_service_name", hw_binder_service_name),
            ("module_name", module_name),
            ("service_name", service_name),
            ("driver_type", driver_type),
            ("file_path", file_path),
            ("bits", bits),
            ("callback_port", callback_port),
            ("driver_caller_uid", caller_uid),
            ("arg", arg),
        )
        for field_name, value in optional_fields:
            if value is not None:
                setattr(command_msg, field_name, value)

        if target_version is not None:
            # Round rather than truncate: e.g. 2.3 * 100 evaluates to
            # 229.99999999999997 in binary floating point.
            command_msg.target_version = int(round(target_version * 100))

        if paths is not None:
            command_msg.paths.extend(paths)

        if shell_command is not None:
            if isinstance(shell_command, (list, tuple)):
                command_msg.shell_command.extend(shell_command)
            else:
                command_msg.shell_command.append(shell_command)

        logging.info("command %s", command_msg)
        message = command_msg.SerializeToString()
        logging.debug("sending %d bytes", len(message))
        # Framing: ASCII decimal length and a newline, then the payload.
        self.channel.write(("%d\n" % len(message)).encode("ascii"))
        self.channel.write(message)
        self.channel.flush()

    def RecvResponse(self, retries=0):
        """Receives and parses the response, and returns the relevant ResponseMessage.

        Args:
            retries: an integer indicating the max number of retries in case of
                     session timeout error.

        Returns:
            An AndroidSystemControlResponseMessage, or None if every attempt
            timed out. An empty length line (e.g. the peer closed the
            connection) yields an empty, default-valued message.

        Raises:
            errors.VtsTcpCommunicationError if not connected.
        """
        if not self.channel:
            raise errors.VtsTcpCommunicationError(
                "channel is None, unable to receive response.")
        for attempt in range(1 + retries):
            if attempt:
                logging.info("retrying...")
            try:
                header = self.channel.readline().strip(b"\n")
                length = int(header) if header else 0
                logging.info("resp %d bytes", length)
                data = self.channel.read(length)
            except socket.timeout as e:
                logging.exception(e)
                continue
            response_msg = SysMsg_pb2.AndroidSystemControlResponseMessage()
            response_msg.ParseFromString(data)
            logging.debug("Response %s", "success"
                          if response_msg.response_code == SysMsg_pb2.SUCCESS
                          else "fail")
            return response_msg
        return None
513