# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

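"""Block-based image diffing for incremental OTAs.

BlockImageDiff (below) takes a target image and an optional source image and
computes a list of transfers, written out as a transfer list plus the
accompanying new-data and patch-data files that are applied on the device.
"""
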
from __future__ import print_function

import array
import copy
import functools
import heapq
import itertools
import logging
import multiprocessing
import os
import os.path
import re
import sys
import threading
import zlib
from collections import deque, namedtuple, OrderedDict
from hashlib import sha1

import common
from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

logger = logging.getLogger(__name__)

# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])


def compute_patch(srcfile, tgtfile, imgdiff=False):
  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
  patchfile = common.MakeTempFile(prefix='patch-')

  cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
  cmd.extend([srcfile, tgtfile, patchfile])

  # Don't dump the bsdiff/imgdiff commands, which are not useful for the case
  # here, since they contain temp filenames only.
  proc = common.Run(cmd, verbose=False)
  output, _ = proc.communicate()

  if proc.returncode != 0:
    raise ValueError(output)

  with open(patchfile, 'rb') as f:
    return PatchInfo(imgdiff, f.read())

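# Illustrative (hypothetical) usage of compute_patch; the temp-file paths are
# made up for the example:
#
#   info = compute_patch("/tmp/src-block.img", "/tmp/tgt-block.img",
#                        imgdiff=False)
#   # info.imgdiff is False and info.content holds the raw bsdiff patch bytes.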

class Image(object):
  def RangeSha1(self, ranges):
    raise NotImplementedError

  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError

  def WriteRangeDataToFd(self, ranges, fd):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""

  def __init__(self):
    self.blocksize = 4096
    self.care_map = RangeSet()
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()
    self.total_blocks = 0
    self.file_map = {}
    self.hashtree_info = None

  def RangeSha1(self, ranges):
    return sha1().hexdigest()

  def ReadRangeSet(self, ranges):
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    raise ValueError("Can't write data from EmptyImage to file")


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) / self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Otherwise the last block may be skipped if it is
    # unchanged for an incremental, but would then fail the post-install
    # verification if the padding bytes contain non-zero data.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    self.clobbered_blocks = clobbered_blocks
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def _GetRangeData(self, ranges):
    for s, e in ranges:
      yield self.data[s*self.blocksize:e*self.blocksize]

  def RangeSha1(self, ranges):
    h = sha1()
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    return list(self._GetRangeData(ranges))

  def TotalSha1(self, include_clobbered_blocks=False):
    if not include_clobbered_blocks:
      return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
    else:
      return sha1(self.data).hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      fd.write(data)

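# For illustration only (hypothetical contents): a DataImage built from two
# 4096-byte blocks, where block 0 is all zeros and block 1 holds data, ends up
# with a file_map that partitions the care_map into
#   {"__ZERO": <RangeSet of block 0>, "__NONZERO": <RangeSet of block 1>}
# plus a "__COPY" entry covering the last block when padding was applied.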

class FileImage(Image):
  """An image wrapped around a raw image file."""

  def __init__(self, path, hashtree_info_generator=None):
    self.path = path
    self.blocksize = 4096
    self._file_size = os.path.getsize(self.path)
    self._file = open(self.path, 'rb')

    if self._file_size % self.blocksize != 0:
      raise ValueError("Size of file %s must be multiple of %d bytes, but is %d"
                       % (self.path, self.blocksize, self._file_size))

    self.total_blocks = self._file_size / self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    self.generator_lock = threading.Lock()

    self.hashtree_info = None
    if hashtree_info_generator:
      self.hashtree_info = hashtree_info_generator.Generate(self)

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self._file.read(self.blocksize)
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks

    self.file_map = {}
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if self.hashtree_info:
      self.file_map["__HASHTREE"] = self.hashtree_info.hashtree_range

  def __del__(self):
    self._file.close()

  def _GetRangeData(self, ranges):
    # Use a lock to protect the generator so that we will not run two
    # instances of this generator on the same object simultaneously.
    with self.generator_lock:
      for s, e in ranges:
        self._file.seek(s * self.blocksize)
        for _ in range(s, e):
          yield self._file.read(self.blocksize)

  def RangeSha1(self, ranges):
    h = sha1()
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    return list(self._GetRangeData(ranges))

  def TotalSha1(self, include_clobbered_blocks=False):
    assert not self.clobbered_blocks
    return self.RangeSha1(self.care_map)

  def WriteRangeDataToFd(self, ranges, fd):
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      fd.write(data)


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
               src_sha1, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.tgt_sha1 = tgt_sha1
    self.src_sha1 = src_sha1
    self.style = style

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

    self._patch_info = None

  @property
  def patch_info(self):
    return self._patch_info

  @patch_info.setter
  def patch_info(self, info):
    if info:
      assert self.style == "diff"
    self._patch_info = info

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def ConvertToNew(self):
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()
    self.patch_info = None

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


@functools.total_ordering
class HeapItem(object):
  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want the
    # maximum score.
    self.score = -item.score

  def clear(self):
    self.item = None

  def __bool__(self):
    return self.item is not None

  # Python 2 uses __nonzero__, while Python 3 uses __bool__.
  __nonzero__ = __bool__

  # The remaining comparison operators are generated by the
  # functools.total_ordering decorator.
  def __eq__(self, other):
    return self.score == other.score

  def __le__(self, other):
    return self.score <= other.score


class ImgdiffStats(object):
  """A class that collects imgdiff stats.

  It keeps track of the files to which imgdiff is applied while generating a
  BlockImageDiff, and logs the ones that cannot use imgdiff along with the
  specific reasons. The stats are only meaningful when imgdiff is not disabled
  by the caller of BlockImageDiff. In addition, only files with supported types
  (BlockImageDiff.FileTypeSupportedByImgdiff()) are allowed to be logged.
  """

  USED_IMGDIFF = "APK files diff'd with imgdiff"
  USED_IMGDIFF_LARGE_APK = "Large APK files split and diff'd with imgdiff"

  # Reasons for not applying imgdiff on APKs.
  SKIPPED_NONMONOTONIC = "Not used imgdiff due to having non-monotonic ranges"
  SKIPPED_SHARED_BLOCKS = "Not used imgdiff due to using shared blocks"
  SKIPPED_INCOMPLETE = "Not used imgdiff due to incomplete RangeSet"

  # The list of valid reasons, which will also be the dumped order in a report.
  REASONS = (
      USED_IMGDIFF,
      USED_IMGDIFF_LARGE_APK,
      SKIPPED_NONMONOTONIC,
      SKIPPED_SHARED_BLOCKS,
      SKIPPED_INCOMPLETE,
  )

  def __init__(self):
    self.stats = {}

  def Log(self, filename, reason):
    """Logs why imgdiff can or cannot be applied to the given filename.

    Args:
      filename: The filename string.
      reason: One of the reason constants listed in REASONS.

    Raises:
      AssertionError: On unsupported filetypes or invalid reason.
    """
    assert BlockImageDiff.FileTypeSupportedByImgdiff(filename)
    assert reason in self.REASONS

    if reason not in self.stats:
      self.stats[reason] = set()
    self.stats[reason].add(filename)

  def Report(self):
    """Prints a report of the collected imgdiff stats."""

    def print_header(header, separator):
      logger.info(header)
      logger.info(separator * len(header) + '\n')

    print_header('  Imgdiff Stats Report  ', '=')
    for key in self.REASONS:
      if key not in self.stats:
        continue
      values = self.stats[key]
      section_header = ' {} (count: {}) '.format(key, len(values))
      print_header(section_header, '-')
      logger.info(''.join(['  {}\n'.format(name) for name in values]))

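# Illustrative (hypothetical) usage of ImgdiffStats, assuming a made-up
# "Example.apk" that gets diff'd with imgdiff:
#
#   stats = ImgdiffStats()
#   stats.Log("Example.apk", ImgdiffStats.USED_IMGDIFF)
#   stats.Report()  # logs a per-reason breakdown of the collected filenames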

class BlockImageDiff(object):
  """Generates the diff of two block image objects.

  BlockImageDiff works on two image objects. An image object is anything that
  provides the following attributes:

     blocksize: the size in bytes of a block, currently must be 4096.

     total_blocks: the total size of the partition/image, in blocks.

     care_map: a RangeSet containing which blocks (in the range [0,
       total_blocks) we actually care about; i.e. which blocks contain data.

     file_map: a dict that partitions the blocks contained in care_map into
         smaller domains that are useful for doing diffs on. (Typically a domain
         is a file, and the key in file_map is the pathname.)

     clobbered_blocks: a RangeSet containing which blocks contain data but may
         be altered by the FS. They need to be excluded when verifying the
         partition integrity.

     ReadRangeSet(): a function that takes a RangeSet and returns the data
         contained in the image blocks of that RangeSet. The data is returned as
         a list or tuple of strings; concatenating the elements together should
         produce the requested data. Implementations are free to break up the
         data into list/tuple elements in any way that is convenient.

     RangeSha1(): a function that returns (as a hex string) the SHA-1 hash of
         all the data in the specified range.

     TotalSha1(): a function that returns (as a hex string) the SHA-1 hash of
         all the data in the image (i.e., all the blocks in the care_map minus
         clobbered_blocks, or including the clobbered blocks if
         include_clobbered_blocks is True).

  When creating a BlockImageDiff, the src image may be None, in which case the
  list of transfers produced will never read from the original image.
  """

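  # A minimal sketch of typical usage (names are illustrative, not from this
  # file): given two image objects as described above,
  #
  #   diff = BlockImageDiff(tgt_image, src_image)
  #   diff.Compute("out/system")
  #
  # writes out/system.transfer.list, out/system.new.dat and
  # out/system.patch.dat for the updater to apply.
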
  def __init__(self, tgt, src=None, threads=None, version=4,
               disable_imgdiff=False):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}
    self._max_stashed_size = 0
    self.touched_src_ranges = RangeSet()
    self.touched_src_sha1 = None
    self.disable_imgdiff = disable_imgdiff
    self.imgdiff_stats = ImgdiffStats() if not disable_imgdiff else None

    assert version in (3, 4)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  @property
  def max_stashed_size(self):
    return self._max_stashed_size

  @staticmethod
  def FileTypeSupportedByImgdiff(filename):
    """Returns whether the file type is supported by imgdiff."""
    return filename.lower().endswith(('.apk', '.jar', '.zip'))

  def CanUseImgdiff(self, name, tgt_ranges, src_ranges, large_apk=False):
    """Checks whether we can apply imgdiff for the given RangeSets.

    For files in ZIP format (e.g., APKs, JARs, etc.) we would like to use
    'imgdiff -z' if possible, because it usually produces significantly smaller
    patches than bsdiff.

    This is permissible if all of the following conditions hold.
      - The imgdiff hasn't been disabled by the caller (e.g. squashfs);
      - The file type is supported by imgdiff;
      - The source and target blocks are monotonic (i.e. the data is stored with
        blocks in increasing order);
      - Both files don't contain shared blocks;
      - Both files have complete lists of blocks;
      - We haven't removed any blocks from the source set.

    If all these conditions are satisfied, concatenating all the blocks in the
    RangeSet in order will produce a valid ZIP file (plus possibly extra zeros
    in the last block). imgdiff is fine with extra zeros at the end of the file.

    Args:
      name: The filename to be diff'd.
      tgt_ranges: The target RangeSet.
      src_ranges: The source RangeSet.
      large_apk: Whether this is to split a large APK.

    Returns:
      A boolean result.
    """
    if self.disable_imgdiff or not self.FileTypeSupportedByImgdiff(name):
      return False

    if not tgt_ranges.monotonic or not src_ranges.monotonic:
      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_NONMONOTONIC)
      return False

    if (tgt_ranges.extra.get('uses_shared_blocks') or
        src_ranges.extra.get('uses_shared_blocks')):
      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_SHARED_BLOCKS)
      return False

    if tgt_ranges.extra.get('incomplete') or src_ranges.extra.get('incomplete'):
      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_INCOMPLETE)
      return False

    reason = (ImgdiffStats.USED_IMGDIFF_LARGE_APK if large_apk
              else ImgdiffStats.USED_IMGDIFF)
    self.imgdiff_stats.Log(name, reason)
    return True

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    self.FindSequenceForTransfers()

    # Ensure the runtime stash size is under the limit.
    if common.OPTIONS.cache_size is not None:
      stash_limit = (common.OPTIONS.cache_size *
                     common.OPTIONS.stash_threshold / self.tgt.blocksize)
      # Ignore the stash limit and calculate the maximum simultaneously stashed
      # blocks needed.
      _, max_stashed_blocks = self.ReviseStashSize(ignore_stash_limit=True)

      # We cannot stash more blocks than the stash limit simultaneously. As a
      # result, some 'diff' commands will be converted to new, leading to an
      # unintentionally large package. To mitigate this issue, we can carefully
      # choose the transfers for conversion. The number '1024' can be further
      # tweaked here to balance the package size and build time.
      if max_stashed_blocks > stash_limit + 1024:
        self.SelectAndConvertDiffTransfersToNew(
            max_stashed_blocks - stash_limit)
        # Regenerate the sequence as the graph has changed.
        self.FindSequenceForTransfers()

      # Revise the stash size again to keep the size under limit.
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()
    self.AssertSha1Good()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

    # Report the imgdiff stats.
    if not self.disable_imgdiff:
      self.imgdiff_stats.Report()

  def WriteTransfers(self, prefix):
    def WriteSplitTransfers(out, style, target_blocks):
      """Limits the operand of 'new' and 'zero' commands to 1024 blocks each.

      This prevents the target size of a single command from being too large,
      and may help to avoid fsync errors on some devices."""

      assert style == "new" or style == "zero"
      blocks_limit = 1024
      total = 0
      while target_blocks:
        blocks_to_write = target_blocks.first(blocks_limit)
        out.append("%s %s\n" % (style, blocks_to_write.to_string_raw()))
        total += blocks_to_write.size()
        target_blocks = target_blocks.subtract(blocks_to_write)
      return total

    out = []
    total = 0

    # In BBOTA v3+, the hash of the stashed blocks is used as the stash slot
    # id. 'stashes' records the map from 'hash' to the ref count. The stash
    # will be freed only when the count drops to zero.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    for xf in self.transfers:

      for _, sr in xf.stash_before:
        sh = self.src.RangeSha1(sr)
        if sh in stashes:
          stashes[sh] += 1
        else:
          stashes[sh] = 1
          stashed_blocks += sr.size()
          self.touched_src_ranges = self.touched_src_ranges.union(sr)
          out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      #   <# blocks> <src ranges>
      #     OR
      #   <# blocks> <src ranges> <src locs> <stash refs...>
      #     OR
      #   <# blocks> - <stash refs...>
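      #
      # For illustration only (hypothetical ranges and hash), the three forms
      # could look like:
      #   "10 2,20,30"                      (all blocks read from the source)
      #   "10 2,20,25 2,0,5 1c2d...:2,5,10" (mix of source reads and a stash)
      #   "10 - 1c2d...:2,0,10"             (everything comes from a stash)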

      size = xf.src_ranges.size()
      src_str_buffer = [str(size)]

      unstashed_src_ranges = xf.src_ranges
      mapped_stashes = []
      for _, sr in xf.use_stash:
        unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
        sh = self.src.RangeSha1(sr)
        sr = xf.src_ranges.map_within(sr)
        mapped_stashes.append(sr)
        assert sh in stashes
        src_str_buffer.append("%s:%s" % (sh, sr.to_string_raw()))
        stashes[sh] -= 1
        if stashes[sh] == 0:
          free_string.append("free %s\n" % (sh,))
          free_size += sr.size()
          stashes.pop(sh)

      if unstashed_src_ranges:
        src_str_buffer.insert(1, unstashed_src_ranges.to_string_raw())
        if xf.use_stash:
          mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
          src_str_buffer.insert(2, mapped_unstashed.to_string_raw())
          mapped_stashes.append(mapped_unstashed)
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
      else:
        src_str_buffer.insert(1, "-")
        self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

      src_str = " ".join(src_str_buffer)

      # version 3+:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>
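      #
      # For example (hypothetical hashes, offsets and ranges), an emitted
      # bsdiff command might read
      #   bsdiff 0 1234 1c2d... 5e6f... 2,10,20 10 2,20,30
      # and a move command
      #   move 5e6f... 2,40,50 10 2,60,70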

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        assert tgt_size == WriteSplitTransfers(out, xf.style, xf.tgt_ranges)
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          self.touched_src_ranges = self.touched_src_ranges.union(
              xf.src_ranges)

          out.append("%s %s %s %s\n" % (
              xf.style,
              xf.tgt_sha1,
              xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        # take into account automatic stashing of overlapping blocks
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          temp_stash_usage = stashed_blocks + xf.src_ranges.size()
          if temp_stash_usage > max_stashed_blocks:
            max_stashed_blocks = temp_stash_usage

        self.touched_src_ranges = self.touched_src_ranges.union(xf.src_ranges)

        out.append("%s %d %d %s %s %s %s\n" % (
            xf.style,
            xf.patch_start, xf.patch_len,
            xf.src_sha1,
            xf.tgt_sha1,
            xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        assert WriteSplitTransfers(out, xf.style, to_zero) == to_zero.size()
        total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two reasons
        # for having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It buys us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize <= max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)

    if self.tgt.hashtree_info:
      out.append("compute_hash_tree {} {} {} {} {}\n".format(
          self.tgt.hashtree_info.hashtree_range.to_string_raw(),
          self.tgt.hashtree_info.filesystem_range.to_string_raw(),
          self.tgt.hashtree_info.hash_algorithm,
          self.tgt.hashtree_info.salt,
          self.tgt.hashtree_info.root_hash))

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      assert (WriteSplitTransfers(out, "zero", self.tgt.extended) ==
              self.tgt.extended.size())
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image; b) will not be touched by dm-verity. Out of those
    # blocks, the ones that won't be used in this update are erased at the
    # beginning of the update, and the rest at the end. This works around an
    # eMMC issue observed on some devices, which may otherwise become starved
    # for clean blocks and thus fail the update. (b/28347095)
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)

    erase_first = new_dontcare.subtract(self.touched_src_ranges)
    if erase_first:
      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))

    erase_last = new_dontcare.subtract(erase_first)
    if erase_last:
      out.append("erase %s\n" % (erase_last.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    # v3+: the number of stash slots is unused.
    out.insert(2, "0\n")
    out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
    OPTIONS = common.OPTIONS
    if OPTIONS.cache_size is not None:
      max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
      logger.info(
          "max stashed blocks: %d  (%d bytes), limit: %d bytes (%.2f%%)\n",
          max_stashed_blocks, self._max_stashed_size, max_allowed,
          self._max_stashed_size * 100.0 / max_allowed)
    else:
      logger.info(
          "max stashed blocks: %d  (%d bytes), limit: <unknown>\n",
          max_stashed_blocks, self._max_stashed_size)

  def ReviseStashSize(self, ignore_stash_limit=False):
    """Revises the transfers to keep the stash size within the size limit.

    Iterates through the transfer list and calculates the stash size each
    transfer generates. Converts the affected transfers to new if we reach the
    stash limit.

    Args:
      ignore_stash_limit: Ignores the stash limit and calculates the max
        simultaneously stashed blocks instead. No change will be made to the
        transfer list with this flag.

    Returns:
      A tuple of (tgt blocks converted to new, max stashed blocks).
    """
    logger.info("Revising stash size...")
    stash_map = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (raw_id, sr), stash_map[raw_id] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for stash_raw_id, sr in xf.stash_before:
        stash_map[stash_raw_id] = (sr, xf)

      # Record all the stashes command xf uses.
      for stash_raw_id, _ in xf.use_stash:
        stash_map[stash_raw_id] += (xf,)

    max_allowed_blocks = None
    if not ignore_stash_limit:
      # Compute the maximum blocks available for stash based on /cache size and
      # the threshold.
      cache_size = common.OPTIONS.cache_size
      stash_threshold = common.OPTIONS.stash_threshold
      max_allowed_blocks = cache_size * stash_threshold / self.tgt.blocksize

    # See the comments for 'stashes' in WriteTransfers().
    stashes = {}
    stashed_blocks = 0
    new_blocks = 0
    max_stashed_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires more stash than is available, we free up the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for stash_raw_id, sr in xf.stash_before:
        # Check the post-command stashed_blocks.
        stashed_blocks_after = stashed_blocks
        sh = self.src.RangeSha1(sr)
        if sh not in stashes:
          stashed_blocks_after += sr.size()

        if max_allowed_blocks and stashed_blocks_after > max_allowed_blocks:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stash_map[stash_raw_id][2]
          replaced_cmds.append(use_cmd)
          logger.info("%10d  %9s  %s", sr.size(), "explicit", use_cmd)
        else:
          # Update the stashes map.
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
          stashed_blocks = stashed_blocks_after
          max_stashed_blocks = max(max_stashed_blocks, stashed_blocks)

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff":
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if (max_allowed_blocks and
              stashed_blocks + xf.src_ranges.size() > max_allowed_blocks):
            replaced_cmds.append(xf)
            logger.info("%10d  %9s  %s", xf.src_ranges.size(), "implicit", xf)
          else:
            # The whole source ranges will be stashed for implicit stashes.
            max_stashed_blocks = max(max_stashed_blocks,
                                     stashed_blocks + xf.src_ranges.size())

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for stash_raw_id, sr in cmd.use_stash:
          def_cmd = stash_map[stash_raw_id][1]
          assert (stash_raw_id, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((stash_raw_id, sr))

        # Add up the blocks that violate the space limit; the total will be
        # reported later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

      # xf.use_stash may generate free commands.
      for _, sr in xf.use_stash:
        sh = self.src.RangeSha1(sr)
        assert sh in stashes
        stashes[sh] -= 1
        if stashes[sh] == 0:
          stashed_blocks -= sr.size()
          stashes.pop(sh)

    num_of_bytes = new_blocks * self.tgt.blocksize
    logger.info(
        "  Total %d blocks (%d bytes) are packed as new blocks due to "
        "insufficient cache size. Maximum blocks stashed simultaneously: %d",
        new_blocks, num_of_bytes, max_stashed_blocks)
    return new_blocks, max_stashed_blocks

  def ComputePatches(self, prefix):
    logger.info("Reticulating splines...")
    diff_queue = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for index, xf in enumerate(self.transfers):
        if xf.style == "zero":
          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
          logger.info(
              "%10d %10d (%6.2f%%) %7s %s %s", tgt_size, tgt_size, 100.0,
              xf.style, xf.tgt_name, str(xf.tgt_ranges))

        elif xf.style == "new":
          self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
          logger.info(
              "%10d %10d (%6.2f%%) %7s %s %s", tgt_size, tgt_size, 100.0,
              xf.style, xf.tgt_name, str(xf.tgt_ranges))

        elif xf.style == "diff":
          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).
          if xf.src_sha1 == xf.tgt_sha1:
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
            xf.patch_info = None
            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
            if xf.src_ranges != xf.tgt_ranges:
              logger.info(
                  "%10d %10d (%6.2f%%) %7s %s %s (from %s)", tgt_size, tgt_size,
                  100.0, xf.style,
                  xf.tgt_name if xf.tgt_name == xf.src_name else (
                      xf.tgt_name + " (from " + xf.src_name + ")"),
                  str(xf.tgt_ranges), str(xf.src_ranges))
          else:
            if xf.patch_info:
              # We have already generated the patch (e.g. during split of large
              # APKs or reduction of stash size)
              imgdiff = xf.patch_info.imgdiff
            else:
              imgdiff = self.CanUseImgdiff(
                  xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_queue.append((index, imgdiff, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    patches = self.ComputePatchesForInputList(diff_queue, False)

    offset = 0
    with open(prefix + ".patch.dat", "wb") as patch_fd:
      for index, patch_info, _ in patches:
        xf = self.transfers[index]
        xf.patch_len = len(patch_info.content)
        xf.patch_start = offset
        offset += xf.patch_len
        patch_fd.write(patch_info.content)

        tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
        logger.info(
            "%10d %10d (%6.2f%%) %7s %s %s %s", xf.patch_len, tgt_size,
            xf.patch_len * 100.0 / tgt_size, xf.style,
            xf.tgt_name if xf.tgt_name == xf.src_name else (
                xf.tgt_name + " (from " + xf.src_name + ")"),
            xf.tgt_ranges, xf.src_ranges)

  def AssertSha1Good(self):
    """Checks the SHA-1 of the src & tgt blocks in the transfer list.

    Double check the SHA-1 value to avoid the issue in b/71908713, where
    SparseImage.RangeSha1() messed up the hash calculation in a multi-threaded
    environment. That specific problem has been fixed by protecting the
    underlying generator function 'SparseImage._GetRangeData()' with a lock.
    """
    for xf in self.transfers:
      tgt_sha1 = self.tgt.RangeSha1(xf.tgt_ranges)
      assert xf.tgt_sha1 == tgt_sha1
      if xf.style == "diff":
        src_sha1 = self.src.RangeSha1(xf.src_ranges)
        assert xf.src_sha1 == src_sha1

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = array.array("B", "\0" * self.tgt.total_blocks)

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      for _, sr in xf.use_stash:
        x = x.subtract(sr)

      for s, e in x:
        # The source image could be larger. Don't check the blocks that are in
        # the source image only, since they are not in 'touched' and will
        # never be touched.
        for i in range(s, min(e, self.tgt.total_blocks)):
          assert touched[i] == 0

      # Check that the output blocks for this transfer haven't yet
      # been touched, and touch all the blocks written by this
      # transfer.
      for s, e in xf.tgt_ranges:
        for i in range(s, e):
          assert touched[i] == 0
          touched[i] = 1

    if self.tgt.hashtree_info:
      for s, e in self.tgt.hashtree_info.hashtree_range:
        for i in range(s, e):
          assert touched[i] == 0
          touched[i] = 1

    # Check that we've written every target block.
    for s, e in self.tgt.care_map:
      for i in range(s, e):
        assert touched[i] == 1

  def FindSequenceForTransfers(self):
    """Finds a sequence for the given transfers.

    The goal is to minimize the violation of order dependencies between these
    transfers, so that fewer blocks are stashed when applying the update.
    """

    # Clear the existing dependency between transfers
    for xf in self.transfers:
      xf.goes_before = OrderedDict()
      xf.goes_after = OrderedDict()

      xf.stash_before = []
      xf.use_stash = []

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.ReverseBackwardEdges()
    self.ImproveVertexSequence()

  def ImproveVertexSequence(self):
    logger.info("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def ReverseBackwardEdges(self):
    """Reverses unsatisfied edges and computes pairs of stashed blocks.

    For each transfer, make sure it properly stashes the blocks that it writes
    and that will be used by later transfers. It uses pairs of (stash_raw_id,
    range) to record the blocks to be stashed. 'stash_raw_id' is an id that
    uniquely identifies each pair. Note that for the same range (e.g.
    RangeSet("1-5")), it is possible to have multiple pairs with different
    'stash_raw_id's. Each 'stash_raw_id' will be consumed by one transfer. In
    BBOTA v3+, identical blocks will be written to the same stash slot in
    WriteTransfers().
    """

    logger.info("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stash_raw_id = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stash_raw_id, overlap))
          xf.use_stash.append((stash_raw_id, overlap))
          stash_raw_id += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    logger.info(
        "  %d/%d dependencies (%.2f%%) were violated; %d source blocks "
        "stashed.", out_of_order, in_order + out_of_order,
        (out_of_order * 100.0 / (in_order + out_of_order)) if (
            in_order + out_of_order) else 0.0,
        stash_size)

  def FindVertexSequence(self):
    logger.info("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.
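    #
    # Each vertex is given a score of sum(outgoing edge weights) minus
    # sum(incoming edge weights); a strongly positive score means the vertex
    # behaves like a source (better placed early, on the left), while a
    # strongly negative score marks a sink (better placed late, on the right).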

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()
      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    heap = []
    for xf in self.transfers:
      xf.heap_item = HeapItem(xf)
      heap.append(xf.heap_item)
    heapq.heapify(heap)

    # Use OrderedDict() instead of set() to preserve the insertion order. Need
    # to use 'sinks[key] = None' to add key into the set. sinks will look like
    # { key1: None, key2: None, ... }.
    sinks = OrderedDict.fromkeys(u for u in G if not u.outgoing)
    sources = OrderedDict.fromkeys(u for u in G if not u.incoming)

    def adjust_score(iu, delta):
      iu.score += delta
      iu.heap_item.clear()
      iu.heap_item = HeapItem(iu)
      heapq.heappush(heap, iu.heap_item)

    while G:
      # Put all sinks at the end of the sequence.
      while sinks:
        new_sinks = OrderedDict()
        for u in sinks:
          if u not in G:
            continue
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            adjust_score(iu, -iu.outgoing.pop(u))
            if not iu.outgoing:
              new_sinks[iu] = None
        sinks = new_sinks

      # Put all the sources at the beginning of the sequence.
      while sources:
        new_sources = OrderedDict()
        for u in sources:
          if u not in G:
            continue
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            adjust_score(iu, +iu.incoming.pop(u))
            if not iu.incoming:
              new_sources[iu] = None
        sources = new_sources

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks saved if we treat it
      # as a source rather than a sink.

      while True:
        u = heapq.heappop(heap)
        if u and u.item in G:
          u = u.item
          break

      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        adjust_score(iu, +iu.incoming.pop(u))
        if not iu.incoming:
          sources[iu] = None

      for iu in u.incoming:
        adjust_score(iu, -iu.outgoing.pop(u))
        if not iu.outgoing:
          sinks[iu] = None

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    logger.info("Generating digraph...")

    # Each item of source_ranges will be:
    #   - None, if that block is not used as a source,
    #   - an ordered set of transfers.
    source_ranges = []
    for b in self.transfers:
      for s, e in b.src_ranges:
        if e > len(source_ranges):
          source_ranges.extend([None] * (e-len(source_ranges)))
        for i in range(s, e):
          if source_ranges[i] is None:
            source_ranges[i] = OrderedDict.fromkeys([b])
          else:
            source_ranges[i][b] = None

    for a in self.transfers:
      intersections = OrderedDict()
      for s, e in a.tgt_ranges:
        for i in range(s, e):
          if i >= len(source_ranges):
            break
          # Add all the Transfers in source_ranges[i] to the (ordered) set.
          if source_ranges[i] is not None:
            for j in source_ranges[i]:
              intersections[j] = None

      for b in intersections:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def ComputePatchesForInputList(self, diff_queue, compress_target):
    """Returns a list of patch information for the input list of transfers.

    Args:
      diff_queue: a list of transfers with style 'diff'.
      compress_target: If True, compresses the target ranges of each transfer
          and saves the compressed size.

    Returns:
      A list of (transfer order, patch_info, compressed_size) tuples.
    """

    if not diff_queue:
      return []

    if self.threads > 1:
      logger.info("Computing patches (using %d threads)...", self.threads)
    else:
      logger.info("Computing patches...")

    diff_total = len(diff_queue)
    patches = [None] * diff_total
    error_messages = []

    # Using multiprocessing doesn't give additional benefits, due to the
    # pattern of the code. The diffing work is done by subprocess.call, which
    # already runs in a separate process (not affected much by the GIL -
    # Global Interpreter Lock). Using multiprocessing would also require either
    # a) writing the diff input files in the main process before forking, or
    # b) reopening the image file (SparseImage) in the worker processes;
    # neither option improves performance further.
    lock = threading.Lock()

    def diff_worker():
      while True:
        with lock:
          if not diff_queue:
            return
          xf_index, imgdiff, patch_index = diff_queue.pop()
          xf = self.transfers[xf_index]

        message = []
        compressed_size = None

        patch_info = xf.patch_info
        if not patch_info:
          src_file = common.MakeTempFile(prefix="src-")
          with open(src_file, "wb") as fd:
            self.src.WriteRangeDataToFd(xf.src_ranges, fd)

          tgt_file = common.MakeTempFile(prefix="tgt-")
          with open(tgt_file, "wb") as fd:
            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)

          try:
            patch_info = compute_patch(src_file, tgt_file, imgdiff)
          except ValueError as e:
            message.append(
                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
                    "imgdiff" if imgdiff else "bsdiff",
                    xf.tgt_name if xf.tgt_name == xf.src_name else
                    xf.tgt_name + " (from " + xf.src_name + ")",
                    xf.tgt_ranges, xf.src_ranges, e.message))

        if compress_target:
          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
          try:
            # Compresses with the default level
            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
            compressed_data = (compress_obj.compress("".join(tgt_data))
                               + compress_obj.flush())
            compressed_size = len(compressed_data)
          except zlib.error as e:
            message.append(
                "Failed to compress the data in target range {} for {}:\n"
                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))

        if message:
          with lock:
            error_messages.extend(message)

        with lock:
          patches[patch_index] = (xf_index, patch_info, compressed_size)

    threads = [threading.Thread(target=diff_worker)
               for _ in range(self.threads)]
    for th in threads:
      th.start()
    while threads:
      threads.pop().join()

    if error_messages:
      logger.error('ERROR:')
      logger.error('\n'.join(error_messages))
      logger.error('\n\n\n')
      sys.exit(1)

    return patches

  def SelectAndConvertDiffTransfersToNew(self, violated_stash_blocks):
    """Converts some 'diff' transfers to reduce the max simultaneous stash.

    Since the 'new' data is compressed with deflate, we can select the 'diff'
    transfers for conversion by comparing each transfer's patch size with the
    size of its compressed target data. Ideally, we want to convert the
    transfers that cause a small size increase but use a large number of
    stashed blocks.
    """
    TransferSizeScore = namedtuple("TransferSizeScore",
                                   "xf, used_stash_blocks, score")

    logger.info("Selecting diff commands to convert to new.")
    diff_queue = []
    for xf in self.transfers:
      if xf.style == "diff" and xf.src_sha1 != xf.tgt_sha1:
        use_imgdiff = self.CanUseImgdiff(xf.tgt_name, xf.tgt_ranges,
                                         xf.src_ranges)
        diff_queue.append((xf.order, use_imgdiff, len(diff_queue)))

    # Remove the 'move' transfers, and compute the patch & compressed size
    # for the remaining.
    result = self.ComputePatchesForInputList(diff_queue, True)

    conversion_candidates = []
    for xf_index, patch_info, compressed_size in result:
      xf = self.transfers[xf_index]
      if not xf.patch_info:
        xf.patch_info = patch_info

      size_ratio = len(xf.patch_info.content) * 100.0 / compressed_size
      diff_style = "imgdiff" if xf.patch_info.imgdiff else "bsdiff"
      logger.info("%s, target size: %d blocks, style: %s, patch size: %d,"
                  " compression_size: %d, ratio %.2f%%", xf.tgt_name,
                  xf.tgt_ranges.size(), diff_style,
                  len(xf.patch_info.content), compressed_size, size_ratio)

      used_stash_blocks = sum(sr.size() for _, sr in xf.use_stash)
      # Convert the transfer to new if the compressed size is smaller or equal.
      # We don't need to maintain the stash_before lists here because the
      # graph will be regenerated later.
      if len(xf.patch_info.content) >= compressed_size:
1463        # Add the transfer to the candidate list with a negative score, so
1464        # that it is always picked for conversion below.
1465        conversion_candidates.append(TransferSizeScore(xf, used_stash_blocks,
1466                                                       -1))
1467      elif used_stash_blocks > 0:
1468        # This heuristic measures the final package size increase per stashed
1469        # block that the conversion would remove.
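        # For example (hypothetical numbers): if a transfer's patch is 90,000
        # bytes while its deflate-compressed target is 100,000 bytes, converting
        # it grows the package by 10,000 bytes; freeing 50 stashed blocks gives
        # a score of 10000 * 100.0 / 50 == 20000.0.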
1470        score = ((compressed_size - len(xf.patch_info.content)) * 100.0
1471                 / used_stash_blocks)
1472        conversion_candidates.append(TransferSizeScore(xf, used_stash_blocks,
1473                                                       score))
1474    # Transfers with lower score (i.e. less expensive to convert) will be
1475    # converted first.
1476    conversion_candidates.sort(key=lambda x: x.score)
1477
1478    # TODO(xunchang), improve the logic to find the transfers to convert, e.g.
1479    # convert the ones that contribute to the max stash, run ReviseStashSize
1480    # multiple times etc.
1481    removed_stashed_blocks = 0
1482    for xf, used_stash_blocks, _ in conversion_candidates:
1483      logger.info("Converting %s to new", xf.tgt_name)
1484      xf.ConvertToNew()
1485      removed_stashed_blocks += used_stash_blocks
1486      # Experiments show that we will get a smaller package size if we remove
1487      # slightly more stashed blocks than the violated stash blocks.
1488      if removed_stashed_blocks >= violated_stash_blocks:
1489        break
1490
1491    logger.info("Removed %d stashed blocks", removed_stashed_blocks)
1492
1493  def FindTransfers(self):
1494    """Parse the file_map to generate all the transfers."""
1495
1496    def AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1497                                             src_ranges, style, by_id):
1498      """Add one or multiple Transfer()s by splitting large files.
1499
1500      For BBOTA v3, we need to stash source blocks for the resumable feature.
1501      However, with growing file sizes and a shrinking cache partition, source
1502      blocks can become too large to stash. If a file occupies too many
1503      blocks, we split it into smaller pieces by creating multiple
1504      Transfer()s.
1505
1506      The downside is that after splitting, we may increase the package size
1507      since the split pieces don't align well. According to our experiments,
1508      1/8 of the cache size as the per-piece limit appears to be optimal.
1509      Compared to the fixed 1024-block limit, it reduces the overall package
1510      size by 30% for volantis, and 20% for angler and bullhead."""
1511
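      # For example (hypothetical sizes): with max_blocks_per_transfer == 1024
      # and a file whose source and target both span 2500 blocks, the loop
      # below emits "<name>-0" and "<name>-1" (1024 blocks each), and the
      # remainder branch emits "<name>-2" for the remaining 452 blocks.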
1512      pieces = 0
1513      while (tgt_ranges.size() > max_blocks_per_transfer and
1514             src_ranges.size() > max_blocks_per_transfer):
1515        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1516        src_split_name = "%s-%d" % (src_name, pieces)
1517        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1518        src_first = src_ranges.first(max_blocks_per_transfer)
1519
1520        Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
1521                 self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
1522                 style, by_id)
1523
1524        tgt_ranges = tgt_ranges.subtract(tgt_first)
1525        src_ranges = src_ranges.subtract(src_first)
1526        pieces += 1
1527
1528      # Handle remaining blocks.
1529      if tgt_ranges.size() or src_ranges.size():
1530        # Must be both non-empty.
1531        assert tgt_ranges.size() and src_ranges.size()
1532        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1533        src_split_name = "%s-%d" % (src_name, pieces)
1534        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
1535                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1536                 style, by_id)
1537
1538    def AddSplitTransfers(tgt_name, src_name, tgt_ranges, src_ranges, style,
1539                          by_id):
1540      """Find all the zip files and split the others with a fixed chunk size.
1541
1542      This function will construct a list of zip archives, which will later be
1543      split by imgdiff to reduce the final patch size. For the other files,
1544      we simply split them into fixed-size chunks, accepting a potential
1545      patch size penalty.
1546      """
1547
1548      assert style == "diff"
1549
1550      # Change nothing for small files.
1551      if (tgt_ranges.size() <= max_blocks_per_transfer and
1552          src_ranges.size() <= max_blocks_per_transfer):
1553        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1554                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1555                 style, by_id)
1556        return
1557
1558      # Split large APKs with imgdiff, if possible. We're intentionally checking
1559      # file types one more time (CanUseImgdiff() checks that as well), before
1560      # calling the costly RangeSha1()s.
1561      if (self.FileTypeSupportedByImgdiff(tgt_name) and
1562          self.tgt.RangeSha1(tgt_ranges) != self.src.RangeSha1(src_ranges)):
1563        if self.CanUseImgdiff(tgt_name, tgt_ranges, src_ranges, True):
1564          large_apks.append((tgt_name, src_name, tgt_ranges, src_ranges))
1565          return
1566
1567      AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1568                                           src_ranges, style, by_id)
1569
1570    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1571                    split=False):
1572      """Wrapper function for adding a Transfer()."""
1573
1574      # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
1575      # otherwise add the Transfer() as is.
1576      if style != "diff" or not split:
1577        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1578                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1579                 style, by_id)
1580        return
1581
1582      # Handle .odex files specially to analyze the block-wise difference. If
1583      # most of the blocks are identical with only a few changes (e.g. the
1584      # header), we will patch the changed blocks only. This avoids stashing
1585      # unchanged blocks while patching. We limit the analysis to files whose
1586      # size is unchanged, to avoid increasing the OTA generation cost too
1587      # much.
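      # For instance (hypothetical case): an .odex file where only the header
      # block changed would be emitted as a "<name>-skipped" transfer covering
      # the identical blocks plus a "<name>-cropped" diff transfer patching
      # just the changed block.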
1588      if (tgt_name.split(".")[-1].lower() == 'odex' and
1589          tgt_ranges.size() == src_ranges.size()):
1590
1591        # The 0.5 threshold can be further tuned. The tradeoff: if only very
1592        # few blocks remain identical, we lose the opportunity to use imgdiff,
1593        # which may achieve a better compression ratio than bsdiff.
1594        crop_threshold = 0.5
1595
1596        tgt_skipped = RangeSet()
1597        src_skipped = RangeSet()
1598        tgt_size = tgt_ranges.size()
1599        tgt_changed = 0
1600        for src_block, tgt_block in zip(src_ranges.next_item(),
1601                                        tgt_ranges.next_item()):
1602          src_rs = RangeSet(str(src_block))
1603          tgt_rs = RangeSet(str(tgt_block))
1604          if self.src.ReadRangeSet(src_rs) == self.tgt.ReadRangeSet(tgt_rs):
1605            tgt_skipped = tgt_skipped.union(tgt_rs)
1606            src_skipped = src_skipped.union(src_rs)
1607          else:
1608            tgt_changed += tgt_rs.size()
1609
1610          # Terminate early if there is no clear sign of benefit.
1611          if tgt_changed > tgt_size * crop_threshold:
1612            break
1613
1614        if tgt_changed < tgt_size * crop_threshold:
1615          assert tgt_changed + tgt_skipped.size() == tgt_size
1616          logger.info(
1617              '%10d %10d (%6.2f%%) %s', tgt_skipped.size(), tgt_size,
1618              tgt_skipped.size() * 100.0 / tgt_size, tgt_name)
1619          AddSplitTransfers(
1620              "%s-skipped" % (tgt_name,),
1621              "%s-skipped" % (src_name,),
1622              tgt_skipped, src_skipped, style, by_id)
1623
1624          # Intentionally change the file extension to avoid being imgdiff'd as
1625          # the files are no longer in their original format.
1626          tgt_name = "%s-cropped" % (tgt_name,)
1627          src_name = "%s-cropped" % (src_name,)
1628          tgt_ranges = tgt_ranges.subtract(tgt_skipped)
1629          src_ranges = src_ranges.subtract(src_skipped)
1630
1631          # There may be no changed blocks left at all.
1632          if not tgt_ranges:
1633            return
1634
1635      # Add the transfer(s).
1636      AddSplitTransfers(
1637          tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1638
1639    def ParseAndValidateSplitInfo(patch_size, tgt_ranges, src_ranges,
1640                                  split_info):
1641      """Parse the split_info and return a list of info tuples.
1642
1643      Args:
1644        patch_size: total size of the patch file.
1645        tgt_ranges: Ranges of the target file within the original image.
1646        src_ranges: Ranges of the source file within the original image.
1647        split_info format:
1648          imgdiff version#
1649          count of pieces
1650          <patch_size_1> <tgt_size_1> <src_ranges_1>
1651          ...
1652          <patch_size_n> <tgt_size_n> <src_ranges_n>
1653
1654      Returns:
1655        A list of (patch_start, patch_length, split_tgt_ranges, split_src_ranges) tuples.
1656      """
1657
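      # A hypothetical split_info might look like (assuming the raw RangeSet
      # format of a count followed by [begin, end) pairs):
      #   2
      #   2
      #   16384 40960 2,0,10
      #   24576 36864 2,10,19
      # i.e. imgdiff version 2, two pieces, each described by its patch size,
      # target size in bytes and source block indices within the file.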
1658      version = int(split_info[0])
1659      assert version == 2
1660      count = int(split_info[1])
1661      assert len(split_info) - 2 == count
1662
1663      split_info_list = []
1664      patch_start = 0
1665      tgt_remain = copy.deepcopy(tgt_ranges)
1666      # Each line has the format: <patch_size> <tgt_size> <src_ranges>
1667      for line in split_info[2:]:
1668        info = line.split()
1669        assert len(info) == 3
1670        patch_length = int(info[0])
1671
1672        split_tgt_size = int(info[1])
1673        assert split_tgt_size % 4096 == 0
1674        assert split_tgt_size / 4096 <= tgt_remain.size()
1675        split_tgt_ranges = tgt_remain.first(split_tgt_size / 4096)
1676        tgt_remain = tgt_remain.subtract(split_tgt_ranges)
1677
1678        # Map the file-relative source indices in info[2] back to block
1679        # ranges within the source image.
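        # For example (hypothetical values): if the source file occupies image
        # blocks "100-199" and info[2] describes file-relative blocks [10, 20),
        # then first(20).subtract(first(10)) yields image blocks "110-119".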
1680        split_src_indices = RangeSet.parse_raw(info[2])
1681        split_src_ranges = RangeSet()
1682        for r in split_src_indices:
1683          curr_range = src_ranges.first(r[1]).subtract(src_ranges.first(r[0]))
1684          assert not split_src_ranges.overlaps(curr_range)
1685          split_src_ranges = split_src_ranges.union(curr_range)
1686
1687        split_info_list.append((patch_start, patch_length,
1688                                split_tgt_ranges, split_src_ranges))
1689        patch_start += patch_length
1690
1691      # Check that the sizes of all the split pieces add up to the final file
1692      # size for patch and target.
1693      assert tgt_remain.size() == 0
1694      assert patch_start == patch_size
1695      return split_info_list
1696
1697    def SplitLargeApks():
1698      """Split the large apks files.
1699
1700      Example: Chrome.apk will be split into
1701        src-0: Chrome.apk-0, tgt-0: Chrome.apk-0
1702        src-1: Chrome.apk-1, tgt-1: Chrome.apk-1
1703        ...
1704
1705      After the split, the target pieces are contiguous and block aligned, and
1706      the source pieces are mutually exclusive. During the split, we also
1707      generate and save the image patch between src-X & tgt-X. This patch will
1708      be valid because the block ranges of src-X & tgt-X will always stay the
1709      same afterwards; but there's a chance we don't use the patch if we
1710      convert the "diff" command into "new" or "move" later.
1711      """
1712
1713      while True:
1714        with transfer_lock:
1715          if not large_apks:
1716            return
1717          tgt_name, src_name, tgt_ranges, src_ranges = large_apks.pop(0)
1718
1719        src_file = common.MakeTempFile(prefix="src-")
1720        tgt_file = common.MakeTempFile(prefix="tgt-")
1721        with open(src_file, "wb") as src_fd:
1722          self.src.WriteRangeDataToFd(src_ranges, src_fd)
1723        with open(tgt_file, "wb") as tgt_fd:
1724          self.tgt.WriteRangeDataToFd(tgt_ranges, tgt_fd)
1725
1726        patch_file = common.MakeTempFile(prefix="patch-")
1727        patch_info_file = common.MakeTempFile(prefix="split_info-")
1728        cmd = ["imgdiff", "-z",
1729               "--block-limit={}".format(max_blocks_per_transfer),
1730               "--split-info=" + patch_info_file,
1731               src_file, tgt_file, patch_file]
1732        proc = common.Run(cmd)
1733        imgdiff_output, _ = proc.communicate()
1734        assert proc.returncode == 0, \
1735            "Failed to create imgdiff patch between {} and {}:\n{}".format(
1736                src_name, tgt_name, imgdiff_output)
1737
1738        with open(patch_info_file) as patch_info:
1739          lines = patch_info.readlines()
1740
1741        patch_size_total = os.path.getsize(patch_file)
1742        split_info_list = ParseAndValidateSplitInfo(patch_size_total,
1743                                                    tgt_ranges, src_ranges,
1744                                                    lines)
1745        for index, (patch_start, patch_length, split_tgt_ranges,
1746                    split_src_ranges) in enumerate(split_info_list):
1747          with open(patch_file, "rb") as f:
1748            f.seek(patch_start)
1749            patch_content = f.read(patch_length)
1750
1751          split_src_name = "{}-{}".format(src_name, index)
1752          split_tgt_name = "{}-{}".format(tgt_name, index)
1753          split_large_apks.append((split_tgt_name,
1754                                   split_src_name,
1755                                   split_tgt_ranges,
1756                                   split_src_ranges,
1757                                   patch_content))
1758
1759    logger.info("Finding transfers...")
1760
1761    large_apks = []
1762    split_large_apks = []
1763    cache_size = common.OPTIONS.cache_size
1764    split_threshold = 0.125
1765    max_blocks_per_transfer = int(cache_size * split_threshold /
1766                                  self.tgt.blocksize)
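    # For example (hypothetical cache size): a 100 MiB cache with 4096-byte
    # blocks gives int(104857600 * 0.125 / 4096) == 3200 blocks per piece.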
1767    empty = RangeSet()
1768    for tgt_fn, tgt_ranges in sorted(self.tgt.file_map.items()):
1769      if tgt_fn == "__ZERO":
1770        # The special "__ZERO" domain covers all the blocks that are not
1771        # contained in any file and are filled with zeros. We have a
1772        # special transfer style for zero blocks.
1773        src_ranges = self.src.file_map.get("__ZERO", empty)
1774        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1775                    "zero", self.transfers)
1776        continue
1777
1778      elif tgt_fn == "__COPY":
1779        # "__COPY" domain includes all the blocks not contained in any
1780        # file and that need to be copied unconditionally to the target.
1781        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1782        continue
1783
1784      elif tgt_fn == "__HASHTREE":
1785        continue
1786
1787      elif tgt_fn in self.src.file_map:
1788        # Look for an exact pathname match in the source.
1789        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1790                    "diff", self.transfers, True)
1791        continue
1792
1793      b = os.path.basename(tgt_fn)
1794      if b in self.src_basenames:
1795        # Look for an exact basename match in the source.
1796        src_fn = self.src_basenames[b]
1797        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1798                    "diff", self.transfers, True)
1799        continue
1800
1801      b = re.sub("[0-9]+", "#", b)
1802      if b in self.src_numpatterns:
1803        # Look for a 'number pattern' match (a basename match after
1804        # all runs of digits are replaced by "#").  (This is useful
1805        # for .so files that contain version numbers in the filename
1806        # that get bumped.)
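        # For example (hypothetical names): "libcrypto-1.0.2.so" in the target
        # and "libcrypto-1.0.1.so" in the source both map to the pattern
        # "libcrypto-#.#.#.so" and are therefore diffed against each other.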
1807        src_fn = self.src_numpatterns[b]
1808        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1809                    "diff", self.transfers, True)
1810        continue
1811
1812      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1813
1814    transfer_lock = threading.Lock()
1815    threads = [threading.Thread(target=SplitLargeApks)
1816               for _ in range(self.threads)]
1817    for th in threads:
1818      th.start()
1819    while threads:
1820      threads.pop().join()
1821
1822    # Sort the split transfers for large apks to generate a deterministic package.
1823    split_large_apks.sort()
1824    for (tgt_name, src_name, tgt_ranges, src_ranges,
1825         patch) in split_large_apks:
1826      transfer_split = Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1827                                self.tgt.RangeSha1(tgt_ranges),
1828                                self.src.RangeSha1(src_ranges),
1829                                "diff", self.transfers)
1830      transfer_split.patch_info = PatchInfo(True, patch)
1831
1832  def AbbreviateSourceNames(self):
1833    for k in self.src.file_map.keys():
1834      b = os.path.basename(k)
1835      self.src_basenames[b] = k
1836      b = re.sub("[0-9]+", "#", b)
1837      self.src_numpatterns[b] = k
1838
1839  @staticmethod
1840  def AssertPartition(total, seq):
1841    """Assert that all the RangeSets in 'seq' form a partition of the
1842    'total' RangeSet (ie, they are nonintersecting and their union
1843    equals 'total')."""
1844
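    # For example (hypothetical ranges): RangeSet("0-9") is partitioned by
    # [RangeSet("0-4"), RangeSet("5-9")], but not by [RangeSet("0-5"),
    # RangeSet("5-9")] (block 5 overlaps) nor by [RangeSet("0-3"),
    # RangeSet("5-9")] (block 4 is missing).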
1845    so_far = RangeSet()
1846    for i in seq:
1847      assert not so_far.overlaps(i)
1848      so_far = so_far.union(i)
1849    assert so_far == total
1850