# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import threading
import tempfile

from rangelib import *

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass

class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
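
# A minimal sketch (not part of the update flow) of how compute_patch
# consumes the list-of-strings form that ReadRangeSet returns.  The data
# here is hypothetical and the bsdiff binary must be on PATH; this helper
# exists only for illustration and is never called.
def _example_compute_patch():
  src = DataImage("a" * 4096 + "\0" * 4096)   # one "a" block, one zero block
  tgt = DataImage("b" * 4096 + "\0" * 4096)   # one "b" block, one zero block
  # Each file_map splits its image into __ZERO and __NONZERO domains.
  assert src.file_map["__ZERO"] == tgt.file_map["__ZERO"]
  # Diff only the nonzero blocks, much as BlockImageDiff does per file.
  return compute_patch(src.ReadRangeSet(src.file_map["__NONZERO"]),
                       tgt.ReadRangeSet(tgt.file_map["__NONZERO"]))
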
class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.id = len(by_id)
    by_id.append(self)

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.
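#
# As a concrete (hypothetical) illustration of this protocol, a trivial
# image holding a single file "/a.txt" that spans blocks [0, 2) could
# look like:
#
#    class OneFileImage(object):
#      blocksize = 4096
#      total_blocks = 2
#      care_map = RangeSet(data=(0, 2))
#      file_map = {"/a.txt": RangeSet(data=(0, 2))}
#      def ReadRangeSet(self, ranges):
#        return [FILE_DATA[s*4096:e*4096] for (s, e) in ranges]
#      def TotalSha1(self):
#        return sha1(FILE_DATA).hexdigest()
#
# where FILE_DATA is the file's content padded to 8192 bytes.
# EmptyImage and DataImage above are working implementations of the
# same protocol.
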
class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    out = []

    out.append("1\n")  # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'" % (xf.style,))

    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)
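
  # For illustration, a written transfer list might look like this
  # (hypothetical offsets and block numbers; a rangeset is serialized
  # by RangeSet.to_string_raw() as a value count followed by start,end
  # pairs):
  #
  #    1                          <- format version
  #    3                          <- total blocks written
  #    move 2,4,5 2,3,4           <- copy block 4 onto block 3
  #    bsdiff 0 214 2,0,1 2,0,2   <- patch bytes [0,214) map [0,1) to [0,2)
  #    erase 2,5,8                <- blocks with no data in the new image
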
  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently,
          # eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check
    # that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been
      # touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map
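
  # A worked example of the edge fixup done below (hypothetical block
  # numbers): suppose diff transfer A reads source blocks [10, 20) and
  # transfer B writes target blocks [15, 18), so A should run before B.
  # If the chosen sequence nevertheless puts B first, the overlap
  # [15, 18) is subtracted from A's source set, leaving A to diff
  # against 7 source blocks instead of 10; a transfer that loses its
  # entire source this way degenerates to style "new".
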
  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") % (
        out_of_order, in_order + out_of_order,
        (out_of_order * 100.0 / (in_order + out_of_order))
        if (in_order + out_of_order) else 0.0,
        lost_source))
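
  # A worked example of the heuristic used below (hypothetical weights):
  # with edges X->Y (weight 3), Y->Z (weight 2) and Z->X (weight 1),
  # there is no pure source or sink, so the vertex maximizing
  # sum(outgoing) - sum(incoming) is emitted first: X scores 3-1 = 2,
  # Y scores 2-3 = -1, Z scores 1-2 = -1.  Emitting X turns Y into a
  # source and Z into a sink, yielding the order X, Y, Z; only the
  # weight-1 edge Z->X ends up pointing left, the cheapest edge this
  # cycle could lose.
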
  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go
        # before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
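

# An illustrative end-to-end sketch (hypothetical in-memory data; real
# callers wrap actual partition images).  Running it would write
# example.transfer.list, example.new.dat and example.patch.dat to the
# current directory; it needs the bsdiff binary on PATH and, like the
# other _example helpers, is never invoked by this module.
def _example_block_image_diff():
  src = DataImage("a" * 4096 + "\0" * 4096)
  tgt = DataImage("b" * 4096 + "\0" * 4096)
  BlockImageDiff(tgt, src, threads=1).Compute("example")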