1# 2# Copyright (C) 2013 The Android Open Source Project 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# 16 17"""Applying a Chrome OS update payload. 18 19This module is used internally by the main Payload class for applying an update 20payload. The interface for invoking the applier is as follows: 21 22 applier = PayloadApplier(payload) 23 applier.Run(...) 24 25""" 26 27from __future__ import absolute_import 28from __future__ import print_function 29 30import array 31import bz2 32import hashlib 33# Not everywhere we can have the lzma library so we ignore it if we didn't have 34# it because it is not going to be used. For example, 'cros flash' uses 35# devserver code which eventually loads this file, but the lzma library is not 36# included in the client test devices, and it is not necessary to do so. But 37# lzma is not used in 'cros flash' so it should be fine. Python 3.x include 38# lzma, but for backward compatibility with Python 2.7, backports-lzma is 39# needed. 40try: 41 import lzma 42except ImportError: 43 try: 44 from backports import lzma 45 except ImportError: 46 pass 47import os 48import subprocess 49import sys 50import tempfile 51 52from update_payload import common 53from update_payload.error import PayloadError 54 55# 56# Helper functions. 57# 58def _VerifySha256(file_obj, expected_hash, name, length=-1): 59 """Verifies the SHA256 hash of a file. 60 61 Args: 62 file_obj: file object to read 63 expected_hash: the hash digest we expect to be getting 64 name: name string of this hash, for error reporting 65 length: precise length of data to verify (optional) 66 67 Raises: 68 PayloadError if computed hash doesn't match expected one, or if fails to 69 read the specified length of data. 70 """ 71 hasher = hashlib.sha256() 72 block_length = 1024 * 1024 73 max_length = length if length >= 0 else sys.maxsize 74 75 while max_length > 0: 76 read_length = min(max_length, block_length) 77 data = file_obj.read(read_length) 78 if not data: 79 break 80 max_length -= len(data) 81 hasher.update(data) 82 83 if length >= 0 and max_length > 0: 84 raise PayloadError( 85 'insufficient data (%d instead of %d) when verifying %s' % 86 (length - max_length, length, name)) 87 88 actual_hash = hasher.digest() 89 if actual_hash != expected_hash: 90 raise PayloadError('%s hash (%s) not as expected (%s)' % 91 (name, common.FormatSha256(actual_hash), 92 common.FormatSha256(expected_hash))) 93 94 95def _ReadExtents(file_obj, extents, block_size, max_length=-1): 96 """Reads data from file as defined by extent sequence. 97 98 This tries to be efficient by not copying data as it is read in chunks. 99 100 Args: 101 file_obj: file object 102 extents: sequence of block extents (offset and length) 103 block_size: size of each block 104 max_length: maximum length to read (optional) 105 106 Returns: 107 A character array containing the concatenated read data. 108 """ 109 data = array.array('B') 110 if max_length < 0: 111 max_length = sys.maxsize 112 for ex in extents: 113 if max_length == 0: 114 break 115 read_length = min(max_length, ex.num_blocks * block_size) 116 117 file_obj.seek(ex.start_block * block_size) 118 data.fromfile(file_obj, read_length) 119 120 max_length -= read_length 121 122 return data 123 124 125def _WriteExtents(file_obj, data, extents, block_size, base_name): 126 """Writes data to file as defined by extent sequence. 127 128 This tries to be efficient by not copy data as it is written in chunks. 129 130 Args: 131 file_obj: file object 132 data: data to write 133 extents: sequence of block extents (offset and length) 134 block_size: size of each block 135 base_name: name string of extent sequence for error reporting 136 137 Raises: 138 PayloadError when things don't add up. 139 """ 140 data_offset = 0 141 data_length = len(data) 142 for ex, ex_name in common.ExtentIter(extents, base_name): 143 if not data_length: 144 raise PayloadError('%s: more write extents than data' % ex_name) 145 write_length = min(data_length, ex.num_blocks * block_size) 146 file_obj.seek(ex.start_block * block_size) 147 file_obj.write(data[data_offset:(data_offset + write_length)]) 148 149 data_offset += write_length 150 data_length -= write_length 151 152 if data_length: 153 raise PayloadError('%s: more data than write extents' % base_name) 154 155 156def _ExtentsToBspatchArg(extents, block_size, base_name, data_length=-1): 157 """Translates an extent sequence into a bspatch-compatible string argument. 158 159 Args: 160 extents: sequence of block extents (offset and length) 161 block_size: size of each block 162 base_name: name string of extent sequence for error reporting 163 data_length: the actual total length of the data in bytes (optional) 164 165 Returns: 166 A tuple consisting of (i) a string of the form 167 "off_1:len_1,...,off_n:len_n", (ii) an offset where zero padding is needed 168 for filling the last extent, (iii) the length of the padding (zero means no 169 padding is needed and the extents cover the full length of data). 170 171 Raises: 172 PayloadError if data_length is too short or too long. 173 """ 174 arg = '' 175 pad_off = pad_len = 0 176 if data_length < 0: 177 data_length = sys.maxsize 178 for ex, ex_name in common.ExtentIter(extents, base_name): 179 if not data_length: 180 raise PayloadError('%s: more extents than total data length' % ex_name) 181 182 start_byte = ex.start_block * block_size 183 num_bytes = ex.num_blocks * block_size 184 if data_length < num_bytes: 185 # We're only padding a real extent. 186 pad_off = start_byte + data_length 187 pad_len = num_bytes - data_length 188 num_bytes = data_length 189 190 arg += '%s%d:%d' % (arg and ',', start_byte, num_bytes) 191 data_length -= num_bytes 192 193 if data_length: 194 raise PayloadError('%s: extents not covering full data length' % base_name) 195 196 return arg, pad_off, pad_len 197 198 199# 200# Payload application. 201# 202class PayloadApplier(object): 203 """Applying an update payload. 204 205 This is a short-lived object whose purpose is to isolate the logic used for 206 applying an update payload. 207 """ 208 209 def __init__(self, payload, bsdiff_in_place=True, bspatch_path=None, 210 puffpatch_path=None, truncate_to_expected_size=True): 211 """Initialize the applier. 212 213 Args: 214 payload: the payload object to check 215 bsdiff_in_place: whether to perform BSDIFF operation in-place (optional) 216 bspatch_path: path to the bspatch binary (optional) 217 puffpatch_path: path to the puffpatch binary (optional) 218 truncate_to_expected_size: whether to truncate the resulting partitions 219 to their expected sizes, as specified in the 220 payload (optional) 221 """ 222 assert payload.is_init, 'uninitialized update payload' 223 self.payload = payload 224 self.block_size = payload.manifest.block_size 225 self.minor_version = payload.manifest.minor_version 226 self.bsdiff_in_place = bsdiff_in_place 227 self.bspatch_path = bspatch_path or 'bspatch' 228 self.puffpatch_path = puffpatch_path or 'puffin' 229 self.truncate_to_expected_size = truncate_to_expected_size 230 231 def _ApplyReplaceOperation(self, op, op_name, out_data, part_file, part_size): 232 """Applies a REPLACE{,_BZ,_XZ} operation. 233 234 Args: 235 op: the operation object 236 op_name: name string for error reporting 237 out_data: the data to be written 238 part_file: the partition file object 239 part_size: the size of the partition 240 241 Raises: 242 PayloadError if something goes wrong. 243 """ 244 block_size = self.block_size 245 data_length = len(out_data) 246 247 # Decompress data if needed. 248 if op.type == common.OpType.REPLACE_BZ: 249 out_data = bz2.decompress(out_data) 250 data_length = len(out_data) 251 elif op.type == common.OpType.REPLACE_XZ: 252 # pylint: disable=no-member 253 out_data = lzma.decompress(out_data) 254 data_length = len(out_data) 255 256 # Write data to blocks specified in dst extents. 257 data_start = 0 258 for ex, ex_name in common.ExtentIter(op.dst_extents, 259 '%s.dst_extents' % op_name): 260 start_block = ex.start_block 261 num_blocks = ex.num_blocks 262 count = num_blocks * block_size 263 264 data_end = data_start + count 265 266 # Make sure we're not running past partition boundary. 267 if (start_block + num_blocks) * block_size > part_size: 268 raise PayloadError( 269 '%s: extent (%s) exceeds partition size (%d)' % 270 (ex_name, common.FormatExtent(ex, block_size), 271 part_size)) 272 273 # Make sure that we have enough data to write. 274 if data_end >= data_length + block_size: 275 raise PayloadError( 276 '%s: more dst blocks than data (even with padding)') 277 278 # Pad with zeros if necessary. 279 if data_end > data_length: 280 padding = data_end - data_length 281 out_data += b'\0' * padding 282 283 self.payload.payload_file.seek(start_block * block_size) 284 part_file.seek(start_block * block_size) 285 part_file.write(out_data[data_start:data_end]) 286 287 data_start += count 288 289 # Make sure we wrote all data. 290 if data_start < data_length: 291 raise PayloadError('%s: wrote fewer bytes (%d) than expected (%d)' % 292 (op_name, data_start, data_length)) 293 294 def _ApplyZeroOperation(self, op, op_name, part_file): 295 """Applies a ZERO operation. 296 297 Args: 298 op: the operation object 299 op_name: name string for error reporting 300 part_file: the partition file object 301 302 Raises: 303 PayloadError if something goes wrong. 304 """ 305 block_size = self.block_size 306 base_name = '%s.dst_extents' % op_name 307 308 # Iterate over the extents and write zero. 309 # pylint: disable=unused-variable 310 for ex, ex_name in common.ExtentIter(op.dst_extents, base_name): 311 part_file.seek(ex.start_block * block_size) 312 part_file.write(b'\0' * (ex.num_blocks * block_size)) 313 314 def _ApplySourceCopyOperation(self, op, op_name, old_part_file, 315 new_part_file): 316 """Applies a SOURCE_COPY operation. 317 318 Args: 319 op: the operation object 320 op_name: name string for error reporting 321 old_part_file: the old partition file object 322 new_part_file: the new partition file object 323 324 Raises: 325 PayloadError if something goes wrong. 326 """ 327 if not old_part_file: 328 raise PayloadError( 329 '%s: no source partition file provided for operation type (%d)' % 330 (op_name, op.type)) 331 332 block_size = self.block_size 333 334 # Gather input raw data from src extents. 335 in_data = _ReadExtents(old_part_file, op.src_extents, block_size) 336 337 # Dump extracted data to dst extents. 338 _WriteExtents(new_part_file, in_data, op.dst_extents, block_size, 339 '%s.dst_extents' % op_name) 340 341 def _BytesInExtents(self, extents, base_name): 342 """Counts the length of extents in bytes. 343 344 Args: 345 extents: The list of Extents. 346 base_name: For error reporting. 347 348 Returns: 349 The number of bytes in extents. 350 """ 351 352 length = 0 353 # pylint: disable=unused-variable 354 for ex, ex_name in common.ExtentIter(extents, base_name): 355 length += ex.num_blocks * self.block_size 356 return length 357 358 def _ApplyDiffOperation(self, op, op_name, patch_data, old_part_file, 359 new_part_file): 360 """Applies a SOURCE_BSDIFF, BROTLI_BSDIFF or PUFFDIFF operation. 361 362 Args: 363 op: the operation object 364 op_name: name string for error reporting 365 patch_data: the binary patch content 366 old_part_file: the source partition file object 367 new_part_file: the target partition file object 368 369 Raises: 370 PayloadError if something goes wrong. 371 """ 372 if not old_part_file: 373 raise PayloadError( 374 '%s: no source partition file provided for operation type (%d)' % 375 (op_name, op.type)) 376 377 block_size = self.block_size 378 379 # Dump patch data to file. 380 with tempfile.NamedTemporaryFile(delete=False) as patch_file: 381 patch_file_name = patch_file.name 382 patch_file.write(patch_data) 383 384 if (hasattr(new_part_file, 'fileno') and 385 ((not old_part_file) or hasattr(old_part_file, 'fileno'))): 386 # Construct input and output extents argument for bspatch. 387 388 in_extents_arg, _, _ = _ExtentsToBspatchArg( 389 op.src_extents, block_size, '%s.src_extents' % op_name, 390 data_length=op.src_length if op.src_length else 391 self._BytesInExtents(op.src_extents, "%s.src_extents")) 392 out_extents_arg, pad_off, pad_len = _ExtentsToBspatchArg( 393 op.dst_extents, block_size, '%s.dst_extents' % op_name, 394 data_length=op.dst_length if op.dst_length else 395 self._BytesInExtents(op.dst_extents, "%s.dst_extents")) 396 397 new_file_name = '/dev/fd/%d' % new_part_file.fileno() 398 # Diff from source partition. 399 old_file_name = '/dev/fd/%d' % old_part_file.fileno() 400 401 # In python3, file descriptors(fd) are not passed to child processes by 402 # default. To pass the fds to the child processes, we need to set the flag 403 # 'inheritable' in the fds and make the subprocess calls with the argument 404 # close_fds set to False. 405 if sys.version_info.major >= 3: 406 os.set_inheritable(new_part_file.fileno(), True) 407 os.set_inheritable(old_part_file.fileno(), True) 408 409 if op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.BROTLI_BSDIFF): 410 # Invoke bspatch on partition file with extents args. 411 bspatch_cmd = [self.bspatch_path, old_file_name, new_file_name, 412 patch_file_name, in_extents_arg, out_extents_arg] 413 subprocess.check_call(bspatch_cmd, close_fds=False) 414 elif op.type == common.OpType.PUFFDIFF: 415 # Invoke puffpatch on partition file with extents args. 416 puffpatch_cmd = [self.puffpatch_path, 417 "--operation=puffpatch", 418 "--src_file=%s" % old_file_name, 419 "--dst_file=%s" % new_file_name, 420 "--patch_file=%s" % patch_file_name, 421 "--src_extents=%s" % in_extents_arg, 422 "--dst_extents=%s" % out_extents_arg] 423 subprocess.check_call(puffpatch_cmd, close_fds=False) 424 else: 425 raise PayloadError("Unknown operation %s" % op.type) 426 427 # Pad with zeros past the total output length. 428 if pad_len: 429 new_part_file.seek(pad_off) 430 new_part_file.write(b'\0' * pad_len) 431 else: 432 # Gather input raw data and write to a temp file. 433 input_part_file = old_part_file if old_part_file else new_part_file 434 in_data = _ReadExtents(input_part_file, op.src_extents, block_size, 435 max_length=op.src_length if op.src_length else 436 self._BytesInExtents(op.src_extents, 437 "%s.src_extents")) 438 with tempfile.NamedTemporaryFile(delete=False) as in_file: 439 in_file_name = in_file.name 440 in_file.write(in_data) 441 442 # Allocate temporary output file. 443 with tempfile.NamedTemporaryFile(delete=False) as out_file: 444 out_file_name = out_file.name 445 446 if op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.BROTLI_BSDIFF): 447 # Invoke bspatch. 448 bspatch_cmd = [self.bspatch_path, in_file_name, out_file_name, 449 patch_file_name] 450 subprocess.check_call(bspatch_cmd) 451 elif op.type == common.OpType.PUFFDIFF: 452 # Invoke puffpatch. 453 puffpatch_cmd = [self.puffpatch_path, 454 "--operation=puffpatch", 455 "--src_file=%s" % in_file_name, 456 "--dst_file=%s" % out_file_name, 457 "--patch_file=%s" % patch_file_name] 458 subprocess.check_call(puffpatch_cmd) 459 else: 460 raise PayloadError("Unknown operation %s" % op.type) 461 462 # Read output. 463 with open(out_file_name, 'rb') as out_file: 464 out_data = out_file.read() 465 if len(out_data) != op.dst_length: 466 raise PayloadError( 467 '%s: actual patched data length (%d) not as expected (%d)' % 468 (op_name, len(out_data), op.dst_length)) 469 470 # Write output back to partition, with padding. 471 unaligned_out_len = len(out_data) % block_size 472 if unaligned_out_len: 473 out_data += b'\0' * (block_size - unaligned_out_len) 474 _WriteExtents(new_part_file, out_data, op.dst_extents, block_size, 475 '%s.dst_extents' % op_name) 476 477 # Delete input/output files. 478 os.remove(in_file_name) 479 os.remove(out_file_name) 480 481 # Delete patch file. 482 os.remove(patch_file_name) 483 484 def _ApplyOperations(self, operations, base_name, old_part_file, 485 new_part_file, part_size): 486 """Applies a sequence of update operations to a partition. 487 488 Args: 489 operations: the sequence of operations 490 base_name: the name of the operation sequence 491 old_part_file: the old partition file object, open for reading/writing 492 new_part_file: the new partition file object, open for reading/writing 493 part_size: the partition size 494 495 Raises: 496 PayloadError if anything goes wrong while processing the payload. 497 """ 498 for op, op_name in common.OperationIter(operations, base_name): 499 # Read data blob. 500 data = self.payload.ReadDataBlob(op.data_offset, op.data_length) 501 502 if op.type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ, 503 common.OpType.REPLACE_XZ): 504 self._ApplyReplaceOperation(op, op_name, data, new_part_file, part_size) 505 elif op.type == common.OpType.ZERO: 506 self._ApplyZeroOperation(op, op_name, new_part_file) 507 elif op.type == common.OpType.SOURCE_COPY: 508 self._ApplySourceCopyOperation(op, op_name, old_part_file, 509 new_part_file) 510 elif op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.PUFFDIFF, 511 common.OpType.BROTLI_BSDIFF): 512 self._ApplyDiffOperation(op, op_name, data, old_part_file, 513 new_part_file) 514 else: 515 raise PayloadError('%s: unknown operation type (%d)' % 516 (op_name, op.type)) 517 518 def _ApplyToPartition(self, operations, part_name, base_name, 519 new_part_file_name, new_part_info, 520 old_part_file_name=None, old_part_info=None): 521 """Applies an update to a partition. 522 523 Args: 524 operations: the sequence of update operations to apply 525 part_name: the name of the partition, for error reporting 526 base_name: the name of the operation sequence 527 new_part_file_name: file name to write partition data to 528 new_part_info: size and expected hash of dest partition 529 old_part_file_name: file name of source partition (optional) 530 old_part_info: size and expected hash of source partition (optional) 531 532 Raises: 533 PayloadError if anything goes wrong with the update. 534 """ 535 # Do we have a source partition? 536 if old_part_file_name: 537 # Verify the source partition. 538 with open(old_part_file_name, 'rb') as old_part_file: 539 _VerifySha256(old_part_file, old_part_info.hash, 540 'old ' + part_name, length=old_part_info.size) 541 new_part_file_mode = 'r+b' 542 open(new_part_file_name, 'w').close() 543 544 else: 545 # We need to create/truncate the dst partition file. 546 new_part_file_mode = 'w+b' 547 548 # Apply operations. 549 with open(new_part_file_name, new_part_file_mode) as new_part_file: 550 old_part_file = (open(old_part_file_name, 'r+b') 551 if old_part_file_name else None) 552 try: 553 self._ApplyOperations(operations, base_name, old_part_file, 554 new_part_file, new_part_info.size) 555 finally: 556 if old_part_file: 557 old_part_file.close() 558 559 # Truncate the result, if so instructed. 560 if self.truncate_to_expected_size: 561 new_part_file.seek(0, 2) 562 if new_part_file.tell() > new_part_info.size: 563 new_part_file.seek(new_part_info.size) 564 new_part_file.truncate() 565 566 # Verify the resulting partition. 567 with open(new_part_file_name, 'rb') as new_part_file: 568 _VerifySha256(new_part_file, new_part_info.hash, 569 'new ' + part_name, length=new_part_info.size) 570 571 def Run(self, new_parts, old_parts=None): 572 """Applier entry point, invoking all update operations. 573 574 Args: 575 new_parts: map of partition name to dest partition file 576 old_parts: map of partition name to source partition file (optional) 577 578 Raises: 579 PayloadError if payload application failed. 580 """ 581 if old_parts is None: 582 old_parts = {} 583 584 self.payload.ResetFile() 585 586 new_part_info = {} 587 old_part_info = {} 588 install_operations = [] 589 590 manifest = self.payload.manifest 591 for part in manifest.partitions: 592 name = part.partition_name 593 new_part_info[name] = part.new_partition_info 594 old_part_info[name] = part.old_partition_info 595 install_operations.append((name, part.operations)) 596 597 part_names = set(new_part_info.keys()) # Equivalently, old_part_info.keys() 598 599 # Make sure the arguments are sane and match the payload. 600 new_part_names = set(new_parts.keys()) 601 if new_part_names != part_names: 602 raise PayloadError('missing dst partition(s) %s' % 603 ', '.join(part_names - new_part_names)) 604 605 old_part_names = set(old_parts.keys()) 606 if part_names - old_part_names: 607 if self.payload.IsDelta(): 608 raise PayloadError('trying to apply a delta update without src ' 609 'partition(s) %s' % 610 ', '.join(part_names - old_part_names)) 611 elif old_part_names == part_names: 612 if self.payload.IsFull(): 613 raise PayloadError('trying to apply a full update onto src partitions') 614 else: 615 raise PayloadError('not all src partitions provided') 616 617 for name, operations in install_operations: 618 # Apply update to partition. 619 self._ApplyToPartition( 620 operations, name, '%s_install_operations' % name, new_parts[name], 621 new_part_info[name], old_parts.get(name, None), old_part_info[name]) 622