• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# Copyright (C) 2013 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17"""Applying a Chrome OS update payload.
18
19This module is used internally by the main Payload class for applying an update
20payload. The interface for invoking the applier is as follows:
21
22  applier = PayloadApplier(payload)
23  applier.Run(...)
24
25"""
26
27from __future__ import print_function
28
29import array
30import bz2
31import hashlib
32import itertools
33# Not everywhere we can have the lzma library so we ignore it if we didn't have
34# it because it is not going to be used. For example, 'cros flash' uses
35# devserver code which eventually loads this file, but the lzma library is not
36# included in the client test devices, and it is not necessary to do so. But
37# lzma is not used in 'cros flash' so it should be fine. Python 3.x include
38# lzma, but for backward compatibility with Python 2.7, backports-lzma is
39# needed.
40try:
41  import lzma
42except ImportError:
43  try:
44    from backports import lzma
45  except ImportError:
46    pass
47import os
48import shutil
49import subprocess
50import sys
51import tempfile
52
53from update_payload import common
54from update_payload.error import PayloadError
55
56
57#
58# Helper functions.
59#
60def _VerifySha256(file_obj, expected_hash, name, length=-1):
61  """Verifies the SHA256 hash of a file.
62
63  Args:
64    file_obj: file object to read
65    expected_hash: the hash digest we expect to be getting
66    name: name string of this hash, for error reporting
67    length: precise length of data to verify (optional)
68
69  Raises:
70    PayloadError if computed hash doesn't match expected one, or if fails to
71    read the specified length of data.
72  """
73  hasher = hashlib.sha256()
74  block_length = 1024 * 1024
75  max_length = length if length >= 0 else sys.maxint
76
77  while max_length > 0:
78    read_length = min(max_length, block_length)
79    data = file_obj.read(read_length)
80    if not data:
81      break
82    max_length -= len(data)
83    hasher.update(data)
84
85  if length >= 0 and max_length > 0:
86    raise PayloadError(
87        'insufficient data (%d instead of %d) when verifying %s' %
88        (length - max_length, length, name))
89
90  actual_hash = hasher.digest()
91  if actual_hash != expected_hash:
92    raise PayloadError('%s hash (%s) not as expected (%s)' %
93                       (name, common.FormatSha256(actual_hash),
94                        common.FormatSha256(expected_hash)))
95
96
def _ReadExtents(file_obj, extents, block_size, max_length=-1):
  """Reads data from file as defined by extent sequence.

  This tries to be efficient by not copying data as it is read in chunks.

  Args:
    file_obj: file object
    extents: sequence of block extents (offset and length)
    block_size: size of each block
    max_length: maximum length to read (optional; -1 means unbounded)

  Returns:
    A character array containing the concatenated read data.
  """
  data = array.array('c')
  if max_length < 0:
    # sys.maxint was removed in Python 3; sys.maxsize exists in both Python 2
    # and 3 and is equally suitable as an "unbounded" sentinel.
    max_length = sys.maxsize
  for ex in extents:
    if max_length == 0:
      break
    read_length = min(max_length, ex.num_blocks * block_size)

    # A pseudo-extent (e.g. a signature placeholder) has no backing blocks;
    # it reads back as zeros. Otherwise, read straight from the file.
    if ex.start_block == common.PSEUDO_EXTENT_MARKER:
      data.extend(itertools.repeat('\0', read_length))
    else:
      file_obj.seek(ex.start_block * block_size)
      data.fromfile(file_obj, read_length)

    max_length -= read_length

  return data
129
130
def _WriteExtents(file_obj, data, extents, block_size, base_name):
  """Writes data to file as defined by extent sequence.

  This tries to be efficient by not copying data as it is written in chunks.

  Args:
    file_obj: file object
    data: data to write
    extents: sequence of block extents (offset and length)
    block_size: size of each block
    base_name: name string of extent sequence for error reporting

  Raises:
    PayloadError when things don't add up.
  """
  offset = 0
  remaining = len(data)
  for ex, ex_name in common.ExtentIter(extents, base_name):
    if not remaining:
      raise PayloadError('%s: more write extents than data' % ex_name)
    chunk_size = min(remaining, ex.num_blocks * block_size)

    # A pseudo-extent has no on-disk destination (e.g. signature data); skip
    # the actual write but still consume the corresponding slice of data.
    if ex.start_block != common.PSEUDO_EXTENT_MARKER:
      file_obj.seek(ex.start_block * block_size)
      # buffer() yields a zero-copy view of data[offset:offset + chunk_size].
      file_obj.write(buffer(data, offset, chunk_size))

    offset += chunk_size
    remaining -= chunk_size

  if remaining:
    raise PayloadError('%s: more data than write extents' % base_name)
164
165
def _ExtentsToBspatchArg(extents, block_size, base_name, data_length=-1):
  """Translates an extent sequence into a bspatch-compatible string argument.

  Args:
    extents: sequence of block extents (offset and length)
    block_size: size of each block
    base_name: name string of extent sequence for error reporting
    data_length: the actual total length of the data in bytes (optional)

  Returns:
    A tuple consisting of (i) a string of the form
    "off_1:len_1,...,off_n:len_n", (ii) an offset where zero padding is needed
    for filling the last extent, (iii) the length of the padding (zero means no
    padding is needed and the extents cover the full length of data).

  Raises:
    PayloadError if data_length is too short or too long.
  """
  arg = ''
  pad_off = pad_len = 0
  if data_length < 0:
    # sys.maxint was removed in Python 3; sys.maxsize exists in both Python 2
    # and 3 and serves the same "unbounded" purpose here.
    data_length = sys.maxsize
  for ex, ex_name in common.ExtentIter(extents, base_name):
    if not data_length:
      raise PayloadError('%s: more extents than total data length' % ex_name)

    # Pseudo-extents have no real file offset; bspatch denotes them with -1.
    is_pseudo = ex.start_block == common.PSEUDO_EXTENT_MARKER
    start_byte = -1 if is_pseudo else ex.start_block * block_size
    num_bytes = ex.num_blocks * block_size
    if data_length < num_bytes:
      # We're only padding a real extent.
      if not is_pseudo:
        pad_off = start_byte + data_length
        pad_len = num_bytes - data_length

      num_bytes = data_length

    # 'arg and ","' prefixes a comma separator for every extent but the first.
    arg += '%s%d:%d' % (arg and ',', start_byte, num_bytes)
    data_length -= num_bytes

  if data_length:
    raise PayloadError('%s: extents not covering full data length' % base_name)

  return arg, pad_off, pad_len
210
211
212#
213# Payload application.
214#
class PayloadApplier(object):
  """Applying an update payload.

  This is a short-lived object whose purpose is to isolate the logic used for
  applying an update payload.
  """

  def __init__(self, payload, bsdiff_in_place=True, bspatch_path=None,
               puffpatch_path=None, truncate_to_expected_size=True):
    """Initialize the applier.

    Args:
      payload: the payload object to check
      bsdiff_in_place: whether to perform BSDIFF operation in-place (optional)
      bspatch_path: path to the bspatch binary (optional)
      puffpatch_path: path to the puffpatch binary (optional)
      truncate_to_expected_size: whether to truncate the resulting partitions
                                 to their expected sizes, as specified in the
                                 payload (optional)
    """
    assert payload.is_init, 'uninitialized update payload'
    self.payload = payload
    self.block_size = payload.manifest.block_size
    self.minor_version = payload.manifest.minor_version
    self.bsdiff_in_place = bsdiff_in_place
    self.bspatch_path = bspatch_path or 'bspatch'
    self.puffpatch_path = puffpatch_path or 'puffin'
    self.truncate_to_expected_size = truncate_to_expected_size

  def _ApplyReplaceOperation(self, op, op_name, out_data, part_file, part_size):
    """Applies a REPLACE{,_BZ,_XZ} operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      out_data: the data to be written
      part_file: the partition file object
      part_size: the size of the partition

    Raises:
      PayloadError if something goes wrong.
    """
    block_size = self.block_size
    data_length = len(out_data)

    # Decompress data if needed.
    if op.type == common.OpType.REPLACE_BZ:
      out_data = bz2.decompress(out_data)
      data_length = len(out_data)
    elif op.type == common.OpType.REPLACE_XZ:
      # pylint: disable=no-member
      out_data = lzma.decompress(out_data)
      data_length = len(out_data)

    # Write data to blocks specified in dst extents.
    data_start = 0
    for ex, ex_name in common.ExtentIter(op.dst_extents,
                                         '%s.dst_extents' % op_name):
      start_block = ex.start_block
      num_blocks = ex.num_blocks
      count = num_blocks * block_size

      # Make sure it's not a fake (signature) operation.
      if start_block != common.PSEUDO_EXTENT_MARKER:
        data_end = data_start + count

        # Make sure we're not running past partition boundary.
        if (start_block + num_blocks) * block_size > part_size:
          raise PayloadError(
              '%s: extent (%s) exceeds partition size (%d)' %
              (ex_name, common.FormatExtent(ex, block_size),
               part_size))

        # Make sure that we have enough data to write.
        # BUGFIX: the format string previously had no '% op_name' applied, so
        # the error message contained a literal '%s'.
        if data_end >= data_length + block_size:
          raise PayloadError(
              '%s: more dst blocks than data (even with padding)' % op_name)

        # Pad with zeros if necessary.
        if data_end > data_length:
          padding = data_end - data_length
          out_data += '\0' * padding

        # BUGFIX: removed a stray seek of the *payload* file to a partition
        # offset; only the partition file's position is relevant here.
        part_file.seek(start_block * block_size)
        part_file.write(out_data[data_start:data_end])

      # Pseudo-extents still consume their share of the data stream.
      data_start += count

    # Make sure we wrote all data.
    if data_start < data_length:
      raise PayloadError('%s: wrote fewer bytes (%d) than expected (%d)' %
                         (op_name, data_start, data_length))

  def _ApplyMoveOperation(self, op, op_name, part_file):
    """Applies a MOVE operation.

    Note that this operation must read the whole block data from the input and
    only then dump it, due to our in-place update semantics; otherwise, it
    might clobber data midway through.

    Args:
      op: the operation object
      op_name: name string for error reporting
      part_file: the partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    block_size = self.block_size

    # Gather input raw data from src extents.
    in_data = _ReadExtents(part_file, op.src_extents, block_size)

    # Dump extracted data to dst extents.
    _WriteExtents(part_file, in_data, op.dst_extents, block_size,
                  '%s.dst_extents' % op_name)

  def _ApplyZeroOperation(self, op, op_name, part_file):
    """Applies a ZERO operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      part_file: the partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    block_size = self.block_size
    base_name = '%s.dst_extents' % op_name

    # Iterate over the extents and write zero.
    # pylint: disable=unused-variable
    for ex, ex_name in common.ExtentIter(op.dst_extents, base_name):
      # Only do actual writing if this is not a pseudo-extent.
      if ex.start_block != common.PSEUDO_EXTENT_MARKER:
        part_file.seek(ex.start_block * block_size)
        part_file.write('\0' * (ex.num_blocks * block_size))

  def _ApplySourceCopyOperation(self, op, op_name, old_part_file,
                                new_part_file):
    """Applies a SOURCE_COPY operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      old_part_file: the old partition file object
      new_part_file: the new partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    if not old_part_file:
      raise PayloadError(
          '%s: no source partition file provided for operation type (%d)' %
          (op_name, op.type))

    block_size = self.block_size

    # Gather input raw data from src extents.
    in_data = _ReadExtents(old_part_file, op.src_extents, block_size)

    # Dump extracted data to dst extents.
    _WriteExtents(new_part_file, in_data, op.dst_extents, block_size,
                  '%s.dst_extents' % op_name)

  def _BytesInExtents(self, extents, base_name):
    """Counts the length of extents in bytes.

    Args:
      extents: The list of Extents.
      base_name: For error reporting.

    Returns:
      The number of bytes in extents.
    """

    length = 0
    # pylint: disable=unused-variable
    for ex, ex_name in common.ExtentIter(extents, base_name):
      length += ex.num_blocks * self.block_size
    return length

  def _ApplyDiffOperation(self, op, op_name, patch_data, old_part_file,
                          new_part_file):
    """Applies a SOURCE_BSDIFF, BROTLI_BSDIFF or PUFFDIFF operation.

    Args:
      op: the operation object
      op_name: name string for error reporting
      patch_data: the binary patch content
      old_part_file: the source partition file object
      new_part_file: the target partition file object

    Raises:
      PayloadError if something goes wrong.
    """
    if not old_part_file:
      raise PayloadError(
          '%s: no source partition file provided for operation type (%d)' %
          (op_name, op.type))

    block_size = self.block_size

    # Dump patch data to file.
    with tempfile.NamedTemporaryFile(delete=False) as patch_file:
      patch_file_name = patch_file.name
      patch_file.write(patch_data)

    if (hasattr(new_part_file, 'fileno') and
        ((not old_part_file) or hasattr(old_part_file, 'fileno'))):
      # Construct input and output extents argument for bspatch.
      # BUGFIX: the base_name literals passed to _BytesInExtents were missing
      # their '% op_name', producing a literal '%s' in error messages.
      in_extents_arg, _, _ = _ExtentsToBspatchArg(
          op.src_extents, block_size, '%s.src_extents' % op_name,
          data_length=op.src_length if op.src_length else
          self._BytesInExtents(op.src_extents, '%s.src_extents' % op_name))
      out_extents_arg, pad_off, pad_len = _ExtentsToBspatchArg(
          op.dst_extents, block_size, '%s.dst_extents' % op_name,
          data_length=op.dst_length if op.dst_length else
          self._BytesInExtents(op.dst_extents, '%s.dst_extents' % op_name))

      new_file_name = '/dev/fd/%d' % new_part_file.fileno()
      # Diff from source partition.
      old_file_name = '/dev/fd/%d' % old_part_file.fileno()

      if op.type in (common.OpType.BSDIFF, common.OpType.SOURCE_BSDIFF,
                     common.OpType.BROTLI_BSDIFF):
        # Invoke bspatch on partition file with extents args.
        bspatch_cmd = [self.bspatch_path, old_file_name, new_file_name,
                       patch_file_name, in_extents_arg, out_extents_arg]
        subprocess.check_call(bspatch_cmd)
      elif op.type == common.OpType.PUFFDIFF:
        # Invoke puffpatch on partition file with extents args.
        puffpatch_cmd = [self.puffpatch_path,
                         "--operation=puffpatch",
                         "--src_file=%s" % old_file_name,
                         "--dst_file=%s" % new_file_name,
                         "--patch_file=%s" % patch_file_name,
                         "--src_extents=%s" % in_extents_arg,
                         "--dst_extents=%s" % out_extents_arg]
        subprocess.check_call(puffpatch_cmd)
      else:
        # BUGFIX: op.type was passed as a second constructor argument instead
        # of being formatted into the message.
        raise PayloadError('Unknown operation %s' % op.type)

      # Pad with zeros past the total output length.
      if pad_len:
        new_part_file.seek(pad_off)
        new_part_file.write('\0' * pad_len)
    else:
      # Gather input raw data and write to a temp file.
      input_part_file = old_part_file if old_part_file else new_part_file
      # BUGFIX: same missing '% op_name' as above.
      in_data = _ReadExtents(input_part_file, op.src_extents, block_size,
                             max_length=op.src_length if op.src_length else
                             self._BytesInExtents(op.src_extents,
                                                  '%s.src_extents' % op_name))
      with tempfile.NamedTemporaryFile(delete=False) as in_file:
        in_file_name = in_file.name
        in_file.write(in_data)

      # Allocate temporary output file.
      with tempfile.NamedTemporaryFile(delete=False) as out_file:
        out_file_name = out_file.name

      if op.type in (common.OpType.BSDIFF, common.OpType.SOURCE_BSDIFF,
                     common.OpType.BROTLI_BSDIFF):
        # Invoke bspatch.
        bspatch_cmd = [self.bspatch_path, in_file_name, out_file_name,
                       patch_file_name]
        subprocess.check_call(bspatch_cmd)
      elif op.type == common.OpType.PUFFDIFF:
        # Invoke puffpatch.
        puffpatch_cmd = [self.puffpatch_path,
                         "--operation=puffpatch",
                         "--src_file=%s" % in_file_name,
                         "--dst_file=%s" % out_file_name,
                         "--patch_file=%s" % patch_file_name]
        subprocess.check_call(puffpatch_cmd)
      else:
        # BUGFIX: format op.type into the message (see above).
        raise PayloadError('Unknown operation %s' % op.type)

      # Read output.
      with open(out_file_name, 'rb') as out_file:
        out_data = out_file.read()
        if len(out_data) != op.dst_length:
          raise PayloadError(
              '%s: actual patched data length (%d) not as expected (%d)' %
              (op_name, len(out_data), op.dst_length))

      # Write output back to partition, with padding.
      unaligned_out_len = len(out_data) % block_size
      if unaligned_out_len:
        out_data += '\0' * (block_size - unaligned_out_len)
      _WriteExtents(new_part_file, out_data, op.dst_extents, block_size,
                    '%s.dst_extents' % op_name)

      # Delete input/output files.
      os.remove(in_file_name)
      os.remove(out_file_name)

    # Delete patch file.
    os.remove(patch_file_name)

  def _ApplyOperations(self, operations, base_name, old_part_file,
                       new_part_file, part_size):
    """Applies a sequence of update operations to a partition.

    This assumes an in-place update semantics for MOVE and BSDIFF, namely all
    reads are performed first, then the data is processed and written back to
    the same file.

    Args:
      operations: the sequence of operations
      base_name: the name of the operation sequence
      old_part_file: the old partition file object, open for reading/writing
      new_part_file: the new partition file object, open for reading/writing
      part_size: the partition size

    Raises:
      PayloadError if anything goes wrong while processing the payload.
    """
    for op, op_name in common.OperationIter(operations, base_name):
      # Read data blob.
      data = self.payload.ReadDataBlob(op.data_offset, op.data_length)

      if op.type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ,
                     common.OpType.REPLACE_XZ):
        self._ApplyReplaceOperation(op, op_name, data, new_part_file, part_size)
      elif op.type == common.OpType.MOVE:
        self._ApplyMoveOperation(op, op_name, new_part_file)
      elif op.type == common.OpType.ZERO:
        self._ApplyZeroOperation(op, op_name, new_part_file)
      elif op.type == common.OpType.BSDIFF:
        # Legacy in-place BSDIFF: the new partition is both source and target.
        self._ApplyDiffOperation(op, op_name, data, new_part_file,
                                 new_part_file)
      elif op.type == common.OpType.SOURCE_COPY:
        self._ApplySourceCopyOperation(op, op_name, old_part_file,
                                       new_part_file)
      elif op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.PUFFDIFF,
                       common.OpType.BROTLI_BSDIFF):
        self._ApplyDiffOperation(op, op_name, data, old_part_file,
                                 new_part_file)
      else:
        raise PayloadError('%s: unknown operation type (%d)' %
                           (op_name, op.type))

  def _ApplyToPartition(self, operations, part_name, base_name,
                        new_part_file_name, new_part_info,
                        old_part_file_name=None, old_part_info=None):
    """Applies an update to a partition.

    Args:
      operations: the sequence of update operations to apply
      part_name: the name of the partition, for error reporting
      base_name: the name of the operation sequence
      new_part_file_name: file name to write partition data to
      new_part_info: size and expected hash of dest partition
      old_part_file_name: file name of source partition (optional)
      old_part_info: size and expected hash of source partition (optional)

    Raises:
      PayloadError if anything goes wrong with the update.
    """
    # Do we have a source partition?
    if old_part_file_name:
      # Verify the source partition.
      with open(old_part_file_name, 'rb') as old_part_file:
        _VerifySha256(old_part_file, old_part_info.hash,
                      'old ' + part_name, length=old_part_info.size)
      new_part_file_mode = 'r+b'
      if self.minor_version == common.INPLACE_MINOR_PAYLOAD_VERSION:
        # Copy the src partition to the dst one; make sure we don't truncate it.
        shutil.copyfile(old_part_file_name, new_part_file_name)
      elif self.minor_version >= common.SOURCE_MINOR_PAYLOAD_VERSION:
        # In minor version >= 2, we don't want to copy the partitions, so
        # instead just make the new partition file.
        open(new_part_file_name, 'w').close()
      else:
        raise PayloadError("Unknown minor version: %d" % self.minor_version)
    else:
      # We need to create/truncate the dst partition file.
      new_part_file_mode = 'w+b'

    # Apply operations.
    with open(new_part_file_name, new_part_file_mode) as new_part_file:
      old_part_file = (open(old_part_file_name, 'r+b')
                       if old_part_file_name else None)
      try:
        self._ApplyOperations(operations, base_name, old_part_file,
                              new_part_file, new_part_info.size)
      finally:
        if old_part_file:
          old_part_file.close()

      # Truncate the result, if so instructed.
      if self.truncate_to_expected_size:
        new_part_file.seek(0, 2)
        if new_part_file.tell() > new_part_info.size:
          new_part_file.seek(new_part_info.size)
          new_part_file.truncate()

    # Verify the resulting partition.
    with open(new_part_file_name, 'rb') as new_part_file:
      _VerifySha256(new_part_file, new_part_info.hash,
                    'new ' + part_name, length=new_part_info.size)

  def Run(self, new_parts, old_parts=None):
    """Applier entry point, invoking all update operations.

    Args:
      new_parts: map of partition name to dest partition file
      old_parts: map of partition name to source partition file (optional)

    Raises:
      PayloadError if payload application failed.
    """
    if old_parts is None:
      old_parts = {}

    self.payload.ResetFile()

    new_part_info = {}
    old_part_info = {}
    install_operations = []

    manifest = self.payload.manifest
    if self.payload.header.version == 1:
      # Major version 1 payloads hard-code the rootfs/kernel partition pair.
      for real_name, proto_name in common.CROS_PARTITIONS:
        new_part_info[real_name] = getattr(manifest, 'new_%s_info' % proto_name)
        old_part_info[real_name] = getattr(manifest, 'old_%s_info' % proto_name)

      install_operations.append((common.ROOTFS, manifest.install_operations))
      install_operations.append((common.KERNEL,
                                 manifest.kernel_install_operations))
    else:
      # Major version 2 payloads carry an explicit partition list.
      for part in manifest.partitions:
        name = part.partition_name
        new_part_info[name] = part.new_partition_info
        old_part_info[name] = part.old_partition_info
        install_operations.append((name, part.operations))

    part_names = set(new_part_info.keys())  # Equivalently, old_part_info.keys()

    # Make sure the arguments are sane and match the payload.
    new_part_names = set(new_parts.keys())
    if new_part_names != part_names:
      raise PayloadError('missing dst partition(s) %s' %
                         ', '.join(part_names - new_part_names))

    old_part_names = set(old_parts.keys())
    if part_names - old_part_names:
      if self.payload.IsDelta():
        raise PayloadError('trying to apply a delta update without src '
                           'partition(s) %s' %
                           ', '.join(part_names - old_part_names))
    elif old_part_names == part_names:
      if self.payload.IsFull():
        raise PayloadError('trying to apply a full update onto src partitions')
    else:
      raise PayloadError('not all src partitions provided')

    for name, operations in install_operations:
      # Apply update to partition.
      self._ApplyToPartition(
          operations, name, '%s_install_operations' % name, new_parts[name],
          new_part_info[name], old_parts.get(name, None), old_part_info[name])
682