• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# Copyright (C) 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Send an A/B update to an Android device over adb."""
19
20from __future__ import print_function
21from __future__ import absolute_import
22
23import argparse
24import binascii
25import hashlib
26import logging
27import os
28import re
29import socket
30import subprocess
31import sys
32import struct
33import tempfile
34import time
35import threading
36import xml.etree.ElementTree
37import zipfile
38
39from six.moves import BaseHTTPServer
40
41import update_payload.payload
42
43
# Device-side directory used to store the OTA package when applying the
# package from a file (--file); main() places the package there as debug.zip.
OTA_PACKAGE_PATH = '/data/ota_package'

# The path to the payload public key on the device. When --public-key is
# given, main() mounts a tmpfs over this file's directory and pushes the
# override key to this path.
PAYLOAD_KEY_PATH = '/etc/update_engine/update-payload-key.pub.pem'

# The port on the device that update_engine should connect to. main() maps it
# with "adb reverse" to the port of the HTTP server running on the host.
DEVICE_PORT = 1234
52
53
def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None, speed_limit=None):
  """Copy from a file object to another.

  This function is similar to shutil.copyfileobj except that it allows to copy
  less than the full source file and to throttle the copy.

  Args:
    fsrc: source file object where to read from.
    fdst: destination file object where to write to.
    buffer_size: size of the copy buffer in memory.
    copy_length: maximum number of bytes to copy, or None to copy everything.
      Zero or a negative value copies nothing.
    speed_limit: upper limit for copying speed, in bytes per second. A falsy
      value (None or 0) disables throttling.

  Returns:
    the number of bytes copied.
  """
  if speed_limit:
    print(f"Applying speed limit: {speed_limit}")
    # If buffer size is significantly bigger than the speed limit, traffic
    # would seem extremely spiky to the client, so shrink the buffer. Keep it
    # at one byte or more: read(0) returns an empty buffer, which the loop
    # below would mistake for end-of-file and copy nothing at all.
    buffer_size = max(1, min(int(speed_limit // 32), buffer_size))

  start_time = time.time()
  copied = 0
  # The loop condition also guards against a negative copy_length, which
  # would otherwise become a negative read size (i.e. "read everything").
  while copy_length is None or copied < copy_length:
    chunk_size = buffer_size
    if copy_length is not None:
      chunk_size = min(chunk_size, copy_length - copied)
    buf = fsrc.read(chunk_size)
    if not buf:
      # Source exhausted before copy_length was reached.
      break
    if speed_limit:
      # Sleep until the bytes written so far fit within the allowed rate.
      expected_duration = copied / speed_limit
      actual_duration = time.time() - start_time
      if actual_duration < expected_duration:
        time.sleep(expected_duration - actual_duration)
    fdst.write(buf)
    copied += len(buf)
  return copied
95
96
class AndroidOTAPackage(object):
  """Android update payload using the .zip format.

  Android OTA packages traditionally used a .zip file to store the payload. When
  applying A/B updates over the network, a payload binary is stored RAW inside
  this .zip file which is used by update_engine to apply the payload. To do
  this, an offset and size inside the .zip file are provided.

  Attributes:
    otafilename: path of the OTA .zip package this object describes.
    offset: byte offset of the payload data from the start of the .zip file.
    size: size in bytes of the payload entry.
    properties: raw bytes of the payload properties entry.
  """

  # Android OTA package file paths.
  OTA_PAYLOAD_BIN = 'payload.bin'
  OTA_PAYLOAD_PROPERTIES_TXT = 'payload_properties.txt'
  SECONDARY_OTA_PAYLOAD_BIN = 'secondary/payload.bin'
  SECONDARY_OTA_PAYLOAD_PROPERTIES_TXT = 'secondary/payload_properties.txt'
  PAYLOAD_MAGIC_HEADER = b'CrAU'

  def __init__(self, otafilename, secondary_payload=False):
    """Locate the payload inside the package and read its properties.

    Args:
      otafilename: path to the OTA .zip package.
      secondary_payload: use the entries under secondary/ instead of the
        top-level payload entries.

    Raises:
      KeyError: if the payload or properties entry is missing from the .zip.
    """
    self.otafilename = otafilename

    if secondary_payload:
      payload_entry = self.SECONDARY_OTA_PAYLOAD_BIN
      property_entry = self.SECONDARY_OTA_PAYLOAD_PROPERTIES_TXT
    else:
      payload_entry = self.OTA_PAYLOAD_BIN
      property_entry = self.OTA_PAYLOAD_PROPERTIES_TXT

    # Use a context manager so the archive handle is released even when an
    # entry is missing or the header cannot be parsed.
    with zipfile.ZipFile(otafilename, 'r') as otazip:
      payload_info = otazip.getinfo(payload_entry)

      if payload_info.compress_type != 0:
        logging.error(
            "Expected payload to be uncompressed, got compression method %d",
            payload_info.compress_type)
      # Don't use len(payload_info.extra). Because that returns size of extra
      # fields in central directory. We need to look at local file directory,
      # as these two might have different sizes.
      with open(otafilename, "rb") as fp:
        fp.seek(payload_info.header_offset)
        data = fp.read(zipfile.sizeFileHeader)
        fheader = struct.unpack(zipfile.structFileHeader, data)
        # Last two fields of local file header are filename length and
        # extra length
        filename_len = fheader[-2]
        extra_len = fheader[-1]
        # The payload data starts right after the fixed-size local header,
        # the file name and the extra field.
        self.offset = (payload_info.header_offset + zipfile.sizeFileHeader +
                       filename_len + extra_len)
        self.size = payload_info.file_size
        fp.seek(self.offset)
        payload_header = fp.read(len(self.PAYLOAD_MAGIC_HEADER))
        if payload_header != self.PAYLOAD_MAGIC_HEADER:
          logging.warning(
              "Invalid header, expected %s, got %s. "
              "Either the offset is not correct, or payload is corrupted",
              binascii.hexlify(self.PAYLOAD_MAGIC_HEADER),
              binascii.hexlify(payload_header))

      self.properties = otazip.read(property_entry)
152
153
class UpdateHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """An HTTP request handler that supports single-range requests.

  It serves one payload (GET /payload) and a minimal Omaha response
  (POST /update). The class is instantiated for every request, so its
  configuration lives in class attributes.

  Attributes:
    serving_payload: path to the only payload file we are serving.
    serving_range: the start offset and size tuple of the payload.
    speed_limit: upper limit for the serving speed in bytes per second, or
      None for no limit.
  """

  # Defaults so that a handler used before being configured answers with a
  # 500 error instead of failing with an AttributeError.
  serving_payload = None
  serving_range = (0, 0)
  speed_limit = None

  @staticmethod
  def _parse_range(range_str, file_size):
    """Parse an HTTP range string.

    Args:
      range_str: HTTP Range header in the request, not including "Header:".
      file_size: total size of the serving file.

    Returns:
      A tuple (start_range, end_range) with the range of bytes requested;
      end_range is exclusive and never exceeds file_size.

    Raises:
      IndexError, ValueError: if range_str is not a well-formed single range.
    """
    start_range = 0
    end_range = file_size

    if range_str:
      range_str = range_str.split('=', 1)[1]
      s, e = range_str.split('-', 1)
      if s:
        start_range = int(s)
        if e:
          # A last-byte position past the end of the file means "up to the
          # end of the file".
          end_range = min(int(e) + 1, file_size)
      elif e:
        # Suffix range: the last |e| bytes of the file.
        if int(e) < file_size:
          start_range = file_size - int(e)
    return start_range, end_range

  def do_GET(self):  # pylint: disable=invalid-name
    """Reply with the requested payload file."""
    if self.path != '/payload':
      self.send_error(404, 'Unknown request')
      return

    if not self.serving_payload:
      self.send_error(500, 'No serving payload set')
      return

    serving_start, serving_size = self.serving_range
    range_header = self.headers.get('Range')
    try:
      start_range, end_range = self._parse_range(range_header, serving_size)
    except (IndexError, ValueError):
      self.send_error(400, 'Invalid Range header')
      return
    if range_header and start_range >= end_range:
      self.send_error(416, 'Requested range not satisfiable')
      return

    try:
      f = open(self.serving_payload, 'rb')
    except IOError:
      self.send_error(404, 'File not found')
      return

    # Close the payload file once the response has been written.
    with f:
      # Handle the range request.
      self.send_response(206 if range_header else 200)

      logging.info('Serving request for %s from %s [%d, %d) length: %d',
                   self.path, self.serving_payload,
                   serving_start + start_range, serving_start + end_range,
                   end_range - start_range)

      self.send_header('Accept-Ranges', 'bytes')
      # The value after '/' is the complete length of the served payload,
      # not the length of the range being returned.
      self.send_header('Content-Range',
                       'bytes ' + str(start_range) + '-' +
                       str(end_range - 1) + '/' + str(serving_size))
      self.send_header('Content-Length', str(end_range - start_range))

      stat = os.fstat(f.fileno())
      self.send_header('Last-Modified', self.date_time_string(stat.st_mtime))
      self.send_header('Content-type', 'application/octet-stream')
      self.end_headers()

      f.seek(serving_start + start_range)
      CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range,
                        speed_limit=self.speed_limit)

  def do_POST(self):  # pylint: disable=invalid-name
    """Reply with the omaha response xml."""
    if self.path != '/update':
      self.send_error(404, 'Unknown request')
      return

    if not self.serving_payload:
      self.send_error(500, 'No serving payload set')
      return

    try:
      f = open(self.serving_payload, 'rb')
    except IOError:
      self.send_error(404, 'File not found')
      return

    with f:
      try:
        content_length = int(self.headers.get('Content-Length', 0))
        request_xml = self.rfile.read(content_length)
        xml_root = xml.etree.ElementTree.fromstring(request_xml)
      except (ValueError, xml.etree.ElementTree.ParseError):
        self.send_error(400, 'Malformed Omaha request')
        return

      appid = None
      for app in xml_root.iter('app'):
        if 'appid' in app.attrib:
          appid = app.attrib['appid']
          break
      if not appid:
        self.send_error(400, 'No appid in Omaha request')
        return

      # Hash the payload before any header is sent, so that a short payload
      # can still be reported with a proper error status.
      serving_start, serving_size = self.serving_range
      sha256 = hashlib.sha256()
      f.seek(serving_start)
      bytes_to_hash = serving_size
      while bytes_to_hash:
        buf = f.read(min(bytes_to_hash, 1024 * 1024))
        if not buf:
          self.send_error(500, 'Payload too small')
          return
        sha256.update(buf)
        bytes_to_hash -= len(buf)

      payload = update_payload.Payload(f, payload_file_offset=serving_start)
      payload.Init()

      response_xml = '''
        <?xml version="1.0" encoding="UTF-8"?>
        <response protocol="3.0">
          <app appid="{appid}">
            <updatecheck status="ok">
              <urls>
                <url codebase="http://127.0.0.1:{port}/"/>
              </urls>
              <manifest version="0.0.0.1">
                <actions>
                  <action event="install" run="payload"/>
                  <action event="postinstall" MetadataSize="{metadata_size}"/>
                </actions>
                <packages>
                  <package hash_sha256="{payload_hash}" name="payload" size="{payload_size}"/>
                </packages>
              </manifest>
            </updatecheck>
          </app>
        </response>
    '''.format(appid=appid, port=DEVICE_PORT,
               metadata_size=payload.metadata_size,
               payload_hash=sha256.hexdigest(),
               payload_size=serving_size)
      # wfile is a binary stream, so the document has to be encoded.
      response_body = response_xml.strip().encode('utf-8')

      self.send_response(200)
      self.send_header("Content-type", "text/xml")
      self.send_header("Content-Length", str(len(response_body)))
      self.end_headers()
      self.wfile.write(response_body)
304
305
class ServerThread(threading.Thread):
  """Thread that runs the local HTTP server handing out the payload."""

  def __init__(self, ota_filename, serving_range, speed_limit):
    """Configure UpdateHandler and bind the server to a free local port.

    Args:
      ota_filename: path of the file to serve.
      serving_range: (start offset, size) tuple of the region to serve.
      speed_limit: serving speed limit passed on to the handler.
    """
    super().__init__()
    # A new UpdateHandler is created for each request, hence the settings
    # are published through class attributes rather than instance state.
    UpdateHandler.serving_payload = ota_filename
    UpdateHandler.serving_range = serving_range
    UpdateHandler.speed_limit = speed_limit
    # Port 0 requests any available port; the chosen one is exposed as
    # self.port.
    listen_address = ('127.0.0.1', 0)
    self._httpd = BaseHTTPServer.HTTPServer(listen_address, UpdateHandler)
    self.port = self._httpd.server_port

  def run(self):
    """Serve requests until the server is shut down or interrupted."""
    httpd = self._httpd
    try:
      httpd.serve_forever()
    except (KeyboardInterrupt, socket.error):
      # Both simply end the serving loop.
      pass
    logging.info('Server Terminated')

  def StopServer(self):
    """Stop the serving loop and close the listening socket."""
    httpd = self._httpd
    httpd.shutdown()
    httpd.socket.close()
329
330
def StartServer(ota_filename, serving_range, speed_limit):
  """Create a ServerThread for the given payload, start it and return it."""
  server_thread = ServerThread(ota_filename, serving_range, speed_limit)
  server_thread.start()
  return server_thread
335
336
def AndroidUpdateCommand(ota_filename, secondary, payload_url, extra_headers):
  """Return the command to run to start the update in the Android device.

  Args:
    ota_filename: path to the OTA package on the host.
    secondary: whether to use the secondary payload of the package.
    payload_url: URL the device should fetch the payload from.
    extra_headers: additional header text appended after the default headers.

  Returns:
    The update_engine_client command line as a list of strings.
  """
  package = AndroidOTAPackage(ota_filename, secondary)
  # Headers are the package properties followed by the fixed entries and
  # finally whatever the caller supplied.
  header_blob = b''.join([
      package.properties,
      b'USER_AGENT=Dalvik (something, something)\n',
      b'NETWORK_ID=0\n',
      extra_headers.encode(),
  ])

  command = ['update_engine_client', '--update', '--follow']
  command.append('--payload=%s' % payload_url)
  command.append('--offset=%d' % package.offset)
  command.append('--size=%d' % package.size)
  command.append('--headers="%s"' % header_blob.decode())
  return command
348
349
def OmahaUpdateCommand(omaha_url):
  """Return the command to run to start the update in a device using Omaha.

  Args:
    omaha_url: URL of the Omaha endpoint the device should query.

  Returns:
    The update_engine_client command line as a list of strings.
  """
  command = ['update_engine_client', '--update', '--follow']
  command.append('--omaha_url=%s' % omaha_url)
  return command
354
355
class AdbHost(object):
  """Represents a device connected via ADB."""

  def __init__(self, device_serial=None):
    """Construct an instance.

    Args:
        device_serial: optional string serial number of attached device.
    """
    self._device_serial = device_serial
    self._command_prefix = ['adb']
    if self._device_serial:
      self._command_prefix += ['-s', self._device_serial]

  def adb(self, command, timeout_seconds: float = None):
    """Run an ADB command like "adb push".

    Args:
      command: list of strings containing command and arguments to run
      timeout_seconds: maximum time to wait for the command to finish, or
        None to wait indefinitely.

    Returns:
      the program's return code. A non-zero code is returned, not raised.

    Raises:
      subprocess.TimeoutExpired: if the command did not finish in time. The
        command is killed before the exception is propagated.
    """
    command = self._command_prefix + command
    logging.info('Running: %s', ' '.join(str(x) for x in command))
    p = subprocess.Popen(command, universal_newlines=True)
    try:
      p.wait(timeout_seconds)
    except subprocess.TimeoutExpired:
      # Don't leave the timed-out command running in the background.
      p.kill()
      p.wait()
      raise
    return p.returncode

  def adb_output(self, command):
    """Run an ADB command like "adb push" and return the output.

    Args:
      command: list of strings containing command and arguments to run

    Returns:
      the program's output as a string.

    Raises:
      subprocess.CalledProcessError on command exit != 0.
    """
    command = self._command_prefix + command
    logging.info('Running: %s', ' '.join(str(x) for x in command))
    return subprocess.check_output(command, universal_newlines=True)
403
404
def PushMetadata(dut, otafile, metadata_path):
  """Push the leading part of the payload of an OTA package to the device.

  Args:
    dut: object whose adb() method runs an adb command and returns its exit
      code.
    otafile: path to the OTA .zip package.
    metadata_path: destination path on the device.

  Returns:
    True if the "adb push" command exited with code 0.
  """
  payload = update_payload.Payload(otafile)
  payload.Init()
  with tempfile.TemporaryDirectory() as tmpdir, \
          zipfile.ZipFile(otafile, "r") as zfp:
    extracted_path = os.path.join(tmpdir, "payload.bin")
    with zfp.open("payload.bin") as payload_fp, \
            open(extracted_path, "wb") as output_fp:
      # Only extract the first |data_offset| bytes from the payload.
      # This is because allocateSpaceForPayload only needs to see
      # the manifest, not the entire payload.
      # Extracting the entire payload works, but is slow for full
      # OTA.
      metadata_blob = payload_fp.read(payload.data_offset)
      output_fp.write(metadata_blob)

    # Push while the temporary directory still exists.
    push_command = ["push", extracted_path, metadata_path]
    exit_code = dut.adb(push_command)
    return exit_code == 0
425
426
def ParseSpeedLimit(arg: str) -> int:
  """Convert a human readable speed limit into bytes per second.

  The accepted format is a number optionally followed by one of the units
  K, M, G or T (case insensitive, powers of 1024), for example "1000", "10K",
  "1.5m" or "3G". Surrounding whitespace is ignored.

  Args:
    arg: the speed limit as given on the command line.

  Returns:
    The speed limit in bytes per second.

  Raises:
    argparse.ArgumentError: if arg does not follow the format above.
  """
  # fullmatch rejects trailing garbage such as "10X" instead of silently
  # accepting the leading digits.
  match = re.fullmatch(r"(\d+(?:\.\d+)?)([KMGT]?)", arg.strip().upper())
  if not match:
    raise argparse.ArgumentError(
        None,
        "Wrong speed limit format, expected format is number followed by unit, such as 10K, 5m, 3G (case insensitive)")
  number, suffix = match.groups()
  units = {
      "": 1,
      "K": 1024,
      "M": 1024 * 1024,
      "G": 1024 * 1024 * 1024,
      "T": 1024 * 1024 * 1024 * 1024,
  }
  unit = units[suffix]
  if "." in number:
    return int(float(number) * unit)
  # Stay in integer arithmetic so that large values are exact.
  return int(number) * unit
446
447
def _BuildArgumentParser():
  """Return the argparse parser describing this tool's command line."""
  parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
  parser.add_argument('otafile', metavar='PAYLOAD', type=str,
                      help='the OTA package file (a .zip file) or raw payload '
                      'if device uses Omaha.')
  parser.add_argument('--file', action='store_true',
                      help='Push the file to the device before updating.')
  parser.add_argument('--no-push', action='store_true',
                      help='Skip the "push" command when using --file')
  parser.add_argument('-s', type=str, default='', metavar='DEVICE',
                      help='The specific device to use.')
  parser.add_argument('--no-verbose', action='store_true',
                      help='Less verbose output')
  parser.add_argument('--public-key', type=str, default='',
                      help='Override the public key used to verify payload.')
  parser.add_argument('--extra-headers', type=str, default='',
                      help='Extra headers to pass to the device.')
  parser.add_argument('--secondary', action='store_true',
                      help='Update with the secondary payload in the package.')
  parser.add_argument('--no-slot-switch', action='store_true',
                      help='Do not perform slot switch after the update.')
  parser.add_argument('--no-postinstall', action='store_true',
                      help='Do not execute postinstall scripts after the update.')
  parser.add_argument('--allocate-only', action='store_true',
                      help='Allocate space for this OTA, instead of actually '
                      'applying the OTA.')
  parser.add_argument('--verify-only', action='store_true',
                      help='Verify metadata then exit, instead of applying the OTA.')
  parser.add_argument('--no-care-map', action='store_true',
                      help='Do not push care_map.pb to device.')
  parser.add_argument('--perform-slot-switch', action='store_true',
                      help='Perform slot switch for this OTA package')
  parser.add_argument('--perform-reset-slot-switch', action='store_true',
                      help='Perform reset slot switch for this OTA package')
  parser.add_argument('--wipe-user-data', action='store_true',
                      help='Wipe userdata after installing OTA')
  parser.add_argument('--vabc-none', action='store_true',
                      help='Set Virtual AB Compression algorithm to none, but still use Android COW format')
  parser.add_argument('--disable-vabc', action='store_true',
                      help='Option to enable or disable vabc. If set to false, will fall back on A/B')
  parser.add_argument('--enable-threading', action='store_true',
                      help='Enable multi-threaded compression for VABC')
  parser.add_argument('--batched-writes', action='store_true',
                      help='Enable batched writes for VABC')
  parser.add_argument('--speed-limit', type=str,
                      help='Speed limit for serving payloads over HTTP. For '
                      'example: 10K, 5m, 1G, input is case insensitive')
  return parser


def main():
  """Parse the command line and send the update to the device over adb.

  Returns:
    The process exit code: 0 once the requested operation has been issued,
    1 if the OTA metadata required for a slot switch could not be pushed.
  """
  parser = _BuildArgumentParser()
  args = parser.parse_args()
  if args.speed_limit:
    try:
      args.speed_limit = ParseSpeedLimit(args.speed_limit)
    except argparse.ArgumentError as e:
      # Report a bad value as a usage error rather than a traceback.
      parser.error(str(e))

  logging.basicConfig(
      level=logging.WARNING if args.no_verbose else logging.INFO)

  start_time = time.perf_counter()

  dut = AdbHost(args.s)

  server_thread = None
  # List of commands to execute on exit.
  finalize_cmds = []
  # Commands to execute when canceling an update.
  cancel_cmd = ['shell', 'su', '0', 'update_engine_client', '--cancel']
  # List of commands to perform the update.
  cmds = []

  # The device is driven through Omaha if its client mentions it in --help.
  help_cmd = ['shell', 'su', '0', 'update_engine_client', '--help']
  use_omaha = 'omaha' in dut.adb_output(help_cmd)

  metadata_path = "/data/ota_package/metadata"
  if args.allocate_only:
    if PushMetadata(dut, args.otafile, metadata_path):
      dut.adb([
          "shell", "update_engine_client", "--allocate",
          "--metadata={}".format(metadata_path)])
    # Return 0, as we are executing ADB commands here, no work needed after
    # this point
    return 0
  if args.verify_only:
    if PushMetadata(dut, args.otafile, metadata_path):
      dut.adb([
          "shell", "update_engine_client", "--verify",
          "--metadata={}".format(metadata_path)])
    # Return 0, as we are executing ADB commands here, no work needed after
    # this point
    return 0
  if args.perform_slot_switch or args.perform_reset_slot_switch:
    # An explicit check instead of "assert": assert statements are removed
    # under "python -O", which would silently skip the push itself.
    if not PushMetadata(dut, args.otafile, metadata_path):
      logging.error('Failed to push OTA metadata to %s', metadata_path)
      return 1
    if args.perform_slot_switch:
      dut.adb(["shell", "update_engine_client",
               "--switch_slot=true", "--metadata={}".format(metadata_path),
               "--follow"])
    else:
      dut.adb(["shell", "update_engine_client",
               "--switch_slot=false", "--metadata={}".format(metadata_path)])
    return 0

  if args.no_slot_switch:
    args.extra_headers += "\nSWITCH_SLOT_ON_REBOOT=0"
  if args.no_postinstall:
    args.extra_headers += "\nRUN_POST_INSTALL=0"
  if args.wipe_user_data:
    args.extra_headers += "\nPOWERWASH=1"
  if args.vabc_none:
    args.extra_headers += "\nVABC_NONE=1"
  if args.disable_vabc:
    args.extra_headers += "\nDISABLE_VABC=1"
  if args.enable_threading:
    args.extra_headers += "\nENABLE_THREADING=1"
  if args.batched_writes:
    args.extra_headers += "\nBATCHED_WRITES=1"

  # PAYLOAD may be a raw payload rather than a .zip when the device uses
  # Omaha; only a .zip package can carry a care map.
  if not args.no_care_map and zipfile.is_zipfile(args.otafile):
    with zipfile.ZipFile(args.otafile) as zfp:
      CARE_MAP_ENTRY_NAME = "care_map.pb"
      if CARE_MAP_ENTRY_NAME in zfp.namelist():
        # Need root permission to push to /data
        dut.adb(["root"])
        with tempfile.NamedTemporaryFile() as care_map_fp:
          care_map_fp.write(zfp.read(CARE_MAP_ENTRY_NAME))
          care_map_fp.flush()
          dut.adb(["push", care_map_fp.name,
                   "/data/ota_package/" + CARE_MAP_ENTRY_NAME])

  if args.file:
    # Update via pushing a file to /data.
    device_ota_file = os.path.join(OTA_PACKAGE_PATH, 'debug.zip')
    payload_url = 'file://' + device_ota_file
    if not args.no_push:
      data_local_tmp_file = '/data/local/tmp/debug.zip'
      cmds.append(['push', args.otafile, data_local_tmp_file])
      cmds.append(['shell', 'su', '0', 'mv', data_local_tmp_file,
                   device_ota_file])
      cmds.append(['shell', 'su', '0', 'chcon',
                   'u:object_r:ota_package_file:s0', device_ota_file])
    cmds.append(['shell', 'su', '0', 'chown', 'system:cache', device_ota_file])
    cmds.append(['shell', 'su', '0', 'chmod', '0660', device_ota_file])
  else:
    # Update via sending the payload over the network with an "adb reverse"
    # command.
    payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
    if use_omaha and zipfile.is_zipfile(args.otafile):
      ota = AndroidOTAPackage(args.otafile, args.secondary)
      serving_range = (ota.offset, ota.size)
    else:
      serving_range = (0, os.stat(args.otafile).st_size)
    server_thread = StartServer(args.otafile, serving_range, args.speed_limit)
    cmds.append(
        ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
    finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])

  if args.public_key:
    payload_key_dir = os.path.dirname(PAYLOAD_KEY_PATH)
    cmds.append(
        ['shell', 'su', '0', 'mount', '-t', 'tmpfs', 'tmpfs', payload_key_dir])
    # Allow adb push to payload_key_dir
    cmds.append(['shell', 'su', '0', 'chcon', 'u:object_r:shell_data_file:s0',
                 payload_key_dir])
    cmds.append(['push', args.public_key, PAYLOAD_KEY_PATH])
    # Allow update_engine to read it.
    cmds.append(['shell', 'su', '0', 'chcon', '-R', 'u:object_r:system_file:s0',
                 payload_key_dir])
    finalize_cmds.append(['shell', 'su', '0', 'umount', payload_key_dir])

  try:
    # The main update command using the configured payload_url.
    if use_omaha:
      update_cmd = \
          OmahaUpdateCommand('http://127.0.0.1:%d/update' % DEVICE_PORT)
    else:
      update_cmd = AndroidUpdateCommand(args.otafile, args.secondary,
                                        payload_url, args.extra_headers)
    cmds.append(['shell', 'su', '0'] + update_cmd)

    for cmd in cmds:
      dut.adb(cmd)
  except KeyboardInterrupt:
    dut.adb(cancel_cmd)
  finally:
    if server_thread:
      server_thread.StopServer()
    for cmd in finalize_cmds:
      # A cleanup command that hangs must not prevent the remaining ones
      # from running.
      try:
        dut.adb(cmd, 5)
      except subprocess.TimeoutExpired:
        logging.warning('Timed out running cleanup command: %s',
                        ' '.join(cmd))

  logging.info('Update took %.3f seconds', (time.perf_counter() - start_time))
  return 0
634
635
# When run as a script, call main() and use its return value as the process
# exit status.
if __name__ == '__main__':
  sys.exit(main())
638