#!/usr/bin/env python
#
# Copyright 2018 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common operations between managing GCE and Cuttlefish devices.

This module provides the common operations between managing GCE (device_driver)
and Cuttlefish (create_cuttlefish_action) devices. It is a helper module and
is not meant to be called directly.
"""

import logging
import os

from acloud import errors
from acloud.public import avd
from acloud.public import report
from acloud.internal import constants
from acloud.internal.lib import utils
from acloud.internal.lib.adb_tools import AdbTools


logger = logging.getLogger(__name__)
# Substrings which, when found in an error message, classify the failure as a
# GCE quota/capacity problem (see _GetErrorType).
_GCE_QUOTA_ERROR_KEYWORDS = [
    "Quota exceeded for quota",
    "ZONE_RESOURCE_POOL_EXHAUSTED",
    "ZONE_RESOURCE_POOL_EXHAUSTED_WITH_DETAILS"]
# Maps a device's recorded creation stage (AndroidVirtualDevice.stage) to the
# error type reported when that device fails to boot.
_DICT_ERROR_TYPE = {
    constants.STAGE_INIT: constants.ACLOUD_INIT_ERROR,
    constants.STAGE_GCE: constants.ACLOUD_CREATE_GCE_ERROR,
    constants.STAGE_SSH_CONNECT: constants.ACLOUD_SSH_CONNECT_ERROR,
    constants.STAGE_ARTIFACT: constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR,
    constants.STAGE_BOOT_UP: constants.ACLOUD_BOOT_UP_ERROR,
}


def CreateSshKeyPairIfNecessary(cfg):
    """Create ssh key pair if necessary.

    A key pair is only generated (if it does not already exist) when both
    ssh_public_key_path and ssh_private_key_path are configured. Otherwise a
    warning explains which key will be used and what the user must provide.

    Args:
        cfg: An AcloudConfig instance.

    Raises:
        Whatever utils.CreateSshKeyPairIfNotExist raises when key generation
        fails (presumably errors.DriverError - verify against utils).
    """
    if not cfg.ssh_public_key_path:
        logger.warning(
            "ssh_public_key_path is not specified in acloud config. "
            "Project-wide public key will "
            "be used when creating AVD instances. "
            "Please ensure you have the correct private half of "
            "a project-wide public key if you want to ssh into the "
            "instances after creation.")
    elif not cfg.ssh_private_key_path:
        # The public key path is known to be set at this point.
        logger.warning(
            "Only ssh_public_key_path is specified in acloud config, "
            "but ssh_private_key_path is missing. "
            "Please ensure you have the correct private half "
            "if you want to ssh into the instances after creation.")
    else:
        # Both paths are set; the previous branches are exhaustive, so no
        # "unexpected condition" fallback is needed.
        utils.CreateSshKeyPairIfNotExist(cfg.ssh_private_key_path,
                                         cfg.ssh_public_key_path)


class DevicePool:
    """A class that manages a pool of virtual devices.

    Attributes:
        devices: A list of devices in the pool.
        compute_client: The compute client obtained from the device factory.
    """

    def __init__(self, device_factory, devices=None):
        """Constructs a new DevicePool.

        Args:
            device_factory: A device factory capable of producing a goldfish or
                cuttlefish device. The device factory must expose an attribute
                with the credentials that can be used to retrieve information
                from the constructed device.
            devices: List of devices managed by this pool.
        """
        self._devices = devices or []
        self._device_factory = device_factory
        self._compute_client = device_factory.GetComputeClient()

    def CreateDevices(self, num):
        """Creates |num| devices using the pool's device factory.

        Args:
            num: Number of devices to create.
        """
        # Create host instances for cuttlefish/goldfish device.
        # Currently one instance supports only 1 device.
        for _ in range(num):
            instance = self._device_factory.CreateInstance()
            ip = self._compute_client.GetInstanceIP(instance)
            # Not every compute client tracks these; fall back to defaults.
            time_info = getattr(self._compute_client, "execution_time", {})
            stage = getattr(self._compute_client, "stage", 0)
            openwrt = getattr(self._compute_client, "openwrt", False)
            self._devices.append(
                avd.AndroidVirtualDevice(ip=ip, instance_name=instance,
                                         time_info=time_info, stage=stage,
                                         openwrt=openwrt))

    @utils.TimeExecute(function_description="Waiting for AVD(s) to boot up",
                       result_evaluator=utils.BootEvaluator)
    def WaitForBoot(self, boot_timeout_secs):
        """Waits for all devices to boot up.

        Args:
            boot_timeout_secs: Integer, the maximum time in seconds used to
                wait for the AVD to boot.

        Returns:
            A dictionary that contains all the failures.
            The key is the name of the instance that fails to boot,
            and the value is an errors.DeviceBootError object.
        """
        failures = {}
        for device in self._devices:
            try:
                self._compute_client.WaitForBoot(device.instance_name,
                                                 boot_timeout_secs)
            except errors.DeviceBootError as e:
                failures[device.instance_name] = e
        return failures

    def UpdateReport(self, reporter):
        """Update report from compute client.

        Args:
            reporter: Report object.
        """
        reporter.UpdateData(self._compute_client.dict_report)

    def CollectSerialPortLogs(self, output_file,
                              port=constants.DEFAULT_SERIAL_PORT):
        """Tar the instance serial logs into specified output_file.

        Args:
            output_file: String, the output tar file path
            port: The serial port number to be collected
        """
        # For emulator, the serial log is the virtual host serial log.
        # For GCE AVD device, the serial log is the AVD device serial log.
        with utils.TempDir() as tempdir:
            src_dict = {}
            for device in self._devices:
                logger.info("Store instance %s serial port %s output to %s",
                            device.instance_name, port, output_file)
                serial_log = self._compute_client.GetSerialPortOutput(
                    instance=device.instance_name, port=port)
                file_name = "%s_serial_%s.log" % (device.instance_name, port)
                file_path = os.path.join(tempdir, file_name)
                src_dict[file_path] = file_name
                # Write UTF-8 bytes through a binary-mode file: writing
                # encoded bytes to a text-mode file raises TypeError on
                # Python 3.
                if isinstance(serial_log, str):
                    serial_log = serial_log.encode("utf-8")
                with open(file_path, "wb") as f:
                    f.write(serial_log)
            utils.MakeTarFile(src_dict, output_file)

    def SetDeviceBuildInfo(self):
        """Add devices build info."""
        for device in self._devices:
            device.build_info = self._device_factory.GetBuildInfoDict()

    @property
    def devices(self):
        """Returns a list of devices in the pool.

        Returns:
            A list of devices in the pool.
        """
        return self._devices

    @property
    def compute_client(self):
        """Returns the compute client obtained from the device factory."""
        return self._compute_client


def _GetErrorType(error):
    """Get proper error type from the exception error.

    Args:
        error: errors object.

    Returns:
        String of error type. e.g. "ACLOUD_BOOT_UP_ERROR".
    """
    if isinstance(error, errors.CheckGCEZonesQuotaError):
        return constants.GCE_QUOTA_ERROR
    if isinstance(error, errors.DownloadArtifactError):
        return constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR
    if isinstance(error, errors.DeviceConnectionError):
        return constants.ACLOUD_SSH_CONNECT_ERROR
    # Quota failures may also surface as generic errors; match on message.
    error_message = str(error)
    if any(keyword in error_message for keyword in _GCE_QUOTA_ERROR_KEYWORDS):
        return constants.GCE_QUOTA_ERROR
    return constants.ACLOUD_UNKNOWN_ERROR


def _GetAdbPort(avd_type, base_instance_num):
    """Get Adb port according to avd_type and device offset.

    Args:
        avd_type: String, the AVD type(cuttlefish, goldfish...).
        base_instance_num: int, device offset (1-based).

    Returns:
        int, adb port, or None if avd_type has no entry in
        utils.AVD_PORT_DICT.
    """
    if avd_type in utils.AVD_PORT_DICT:
        return utils.AVD_PORT_DICT[avd_type].adb_port + base_instance_num - 1
    return None


# pylint: disable=too-many-locals,unused-argument,too-many-branches
def CreateDevices(command, cfg, device_factory, num, avd_type,
                  report_internal_ip=False, autoconnect=False,
                  serial_log_file=None, client_adb_port=None,
                  boot_timeout_secs=None, unlock_screen=False,
                  wait_for_boot=True, connect_webrtc=False,
                  ssh_private_key_path=None,
                  ssh_user=constants.GCE_USER):
    """Create a set of devices using the given factory.

    Main jobs in create devices.
        1. Create GCE instance: Launch instance in GCP(Google Cloud Platform).
        2. Starting up AVD: Wait device boot up.

    Args:
        command: The name of the command, used for reporting.
        cfg: An AcloudConfig instance.
        device_factory: A factory capable of producing a single device.
        num: The number of devices to create.
        avd_type: String, the AVD type(cuttlefish, goldfish...).
        report_internal_ip: Boolean to report the internal ip instead of
                            external ip.
        autoconnect: Boolean, whether to auto connect to device.
        serial_log_file: String, the file path to tar the serial logs.
        client_adb_port: Integer, Specify port for adb forwarding.
        boot_timeout_secs: Integer, boot timeout secs.
        unlock_screen: Boolean, whether to unlock screen after invoke vnc client.
        wait_for_boot: Boolean, True to check serial log include boot up
                       message.
        connect_webrtc: Boolean, whether to auto connect webrtc to device.
        ssh_private_key_path: String, the private key for SSH tunneling.
        ssh_user: String, the user name for SSH tunneling.

    Returns:
        A Report instance. errors.DriverError and
        errors.CheckGCEZonesQuotaError are not raised; they are recorded in
        the report with status FAIL. Other exceptions propagate.
    """
    reporter = report.Report(command=command)
    try:
        CreateSshKeyPairIfNecessary(cfg)
        device_pool = DevicePool(device_factory)
        device_pool.CreateDevices(num)
        device_pool.SetDeviceBuildInfo()
        if wait_for_boot:
            failures = device_pool.WaitForBoot(boot_timeout_secs)
        else:
            failures = device_factory.GetFailures()

        if failures:
            reporter.SetStatus(report.Status.BOOT_FAIL)
        else:
            reporter.SetStatus(report.Status.SUCCESS)

        # Collect logs
        logs = device_factory.GetLogs()
        if serial_log_file:
            device_pool.CollectSerialPortLogs(
                serial_log_file, port=constants.DEFAULT_SERIAL_PORT)

        device_pool.UpdateReport(reporter)

        # These values are the same for every device, so compute them once.
        compute_report = device_pool.compute_client.dict_report
        base_instance_num = 1
        if constants.BASE_INSTANCE_NUM in compute_report:
            base_instance_num = compute_report[constants.BASE_INSTANCE_NUM]
        adb_port = _GetAdbPort(avd_type, base_instance_num)
        rsa_key_file = ssh_private_key_path or cfg.ssh_private_key_path
        is_success = reporter.status == report.Status.SUCCESS

        # Write result to report.
        for device in device_pool.devices:
            ip = (device.ip.internal if report_internal_ip
                  else device.ip.external)
            device_dict = {
                "ip": ip + (":" + str(adb_port) if adb_port else ""),
                "instance_name": device.instance_name
            }
            if device.build_info:
                device_dict.update(device.build_info)
            if device.time_info:
                device_dict.update(device.time_info)
            if device.openwrt:
                device_dict.update(device_factory.GetOpenWrtInfoDict())
            if autoconnect and is_success:
                forwarded_ports = utils.AutoConnect(
                    ip_addr=ip,
                    rsa_key_file=rsa_key_file,
                    target_vnc_port=utils.AVD_PORT_DICT[avd_type].vnc_port,
                    target_adb_port=adb_port,
                    ssh_user=ssh_user,
                    client_adb_port=client_adb_port,
                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
                device_dict[constants.VNC_PORT] = forwarded_ports.vnc_port
                device_dict[constants.ADB_PORT] = forwarded_ports.adb_port
                device_dict[constants.DEVICE_SERIAL] = (
                    constants.REMOTE_INSTANCE_ADB_SERIAL %
                    forwarded_ports.adb_port)
                if unlock_screen:
                    AdbTools(forwarded_ports.adb_port).AutoUnlockScreen()
            if connect_webrtc and is_success:
                webrtc_local_port = utils.PickFreePort()
                device_dict[constants.WEBRTC_PORT] = webrtc_local_port
                utils.EstablishWebRTCSshTunnel(
                    ip_addr=ip,
                    webrtc_local_port=webrtc_local_port,
                    rsa_key_file=rsa_key_file,
                    ssh_user=ssh_user,
                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
            if device.instance_name in logs:
                device_dict[constants.LOGS] = logs[device.instance_name]
            if device.instance_name in failures:
                reporter.SetErrorType(constants.ACLOUD_BOOT_UP_ERROR)
                # A recorded stage gives a more specific error type.
                if device.stage:
                    reporter.SetErrorType(_DICT_ERROR_TYPE[device.stage])
                reporter.AddData(key="devices_failing_boot", value=device_dict)
                reporter.AddError(str(failures[device.instance_name]))
            else:
                reporter.AddData(key="devices", value=device_dict)
    except (errors.DriverError, errors.CheckGCEZonesQuotaError) as e:
        reporter.SetErrorType(_GetErrorType(e))
        reporter.AddError(str(e))
        reporter.SetStatus(report.Status.FAIL)
    return reporter