• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import array
20import common
21import functools
22import heapq
23import itertools
24import multiprocessing
25import os
26import re
27import subprocess
28import threading
29import time
30import tempfile
31
32from rangelib import RangeSet
33
34
35__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
36
37
def compute_patch(src, tgt, imgdiff=False):
  """Compute the patch needed to turn src into tgt, via an external tool.

  Args:
    src: an iterable of strings holding the source data.
    tgt: an iterable of strings holding the target data.
    imgdiff: if True, run "imgdiff -z" (for zip-format data); otherwise
        run "bsdiff".

  Returns:
    The raw patch data read back from the tool's output file.

  Raises:
    ValueError: if the external diff command exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # Remove the placeholder patch file; the diff tool creates its own.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Use a context manager for the null sink; the original code leaked
      # the open("/dev/null") file handle.
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Unlink each temp file independently so that a failure on one does not
    # leave the others behind (the original stopped at the first OSError).
    for fn in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(fn)
      except OSError:
        pass
75
76
class Image(object):
  """Abstract interface for the image objects consumed by BlockImageDiff."""

  def ReadRangeSet(self, ranges):
    """Return the data contained in the given RangeSet; subclasses override."""
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 hex digest of the image data; subclasses override."""
    raise NotImplementedError
83
84
class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  # An empty image has no blocks to care about, clobber, or extend.
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    # No blocks, hence no data.
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()
100
101
class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Build an image from raw data.

    Args:
      data: the partition contents as a single string.
      trim: if True, drop a trailing partial block rather than failing.
      pad: if True, zero-fill a trailing partial block rather than failing.

    Raises:
      ValueError: if data is not block-aligned and neither trim nor pad
          was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int (true division would yield a
    # float under Python 3 and break RangeSet/range below).
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # clobbered_blocks must be a RangeSet, not a plain list: callers invoke
    # .size() and .subtract() on it (see the Image contract and TotalSha1).
    self.clobbered_blocks = RangeSet(
        data=clobbered_blocks) if clobbered_blocks else RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify every block (except a padded final block, which is handled
    # via clobbered_blocks) as all-zero or not.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    """Return a list of data pieces, one per (start, end) block range."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 hex digest, optionally skipping clobbered blocks."""
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # sha1() requires a single string; ReadRangeSet returns a list of
      # pieces, so join them first (passing the list raises TypeError).
      return sha1("".join(self.ReadRangeSet(ranges))).hexdigest()
    else:
      return sha1(self.data).hexdigest()
172
173
class Transfer(object):
  """One transfer command, with its ordering and stash bookkeeping."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # imgdiff is only usable when both range sets are monotonic.
    tgt_monotonic = getattr(tgt_ranges, "monotonic", False)
    src_monotonic = getattr(src_ranges, "monotonic", False)
    self.intact = tgt_monotonic and src_monotonic

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Return blocks stashed by this command minus blocks it frees."""
    stored = sum(sr.size() for (_, sr) in self.stash_before)
    freed = sum(sr.size() for (_, sr) in self.use_stash)
    return stored - freed

  def ConvertToNew(self):
    """Turn this command into a plain "new" that reads no source blocks."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return "%d: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
208
209
@functools.total_ordering
class HeapItem(object):
  """Heap entry that orders wrapped items by descending score."""

  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want
    # the maximum score.
    self.score = -item.score

  def clear(self):
    """Drop the wrapped item, marking this heap entry as dead."""
    self.item = None

  def __bool__(self):
    # Bug fix: an entry is truthy while it still HOLDS an item. The
    # original returned "self.item is None", inverting the test so that
    # cleared entries looked live and live entries looked cleared.
    return self.item is not None

  # Python 2 consults __nonzero__ for truth testing; keep both in sync.
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self.score == other.score

  def __le__(self, other):
    return self.score <= other.score
226
227# BlockImageDiff works on two image objects.  An image object is
228# anything that provides the following attributes:
229#
230#    blocksize: the size in bytes of a block, currently must be 4096.
231#
232#    total_blocks: the total size of the partition/image, in blocks.
233#
234#    care_map: a RangeSet containing which blocks (in the range [0,
235#      total_blocks) we actually care about; i.e. which blocks contain
236#      data.
237#
238#    file_map: a dict that partitions the blocks contained in care_map
239#      into smaller domains that are useful for doing diffs on.
240#      (Typically a domain is a file, and the key in file_map is the
241#      pathname.)
242#
243#    clobbered_blocks: a RangeSet containing which blocks contain data
244#      but may be altered by the FS. They need to be excluded when
245#      verifying the partition integrity.
246#
247#    ReadRangeSet(): a function that takes a RangeSet and returns the
248#      data contained in the image blocks of that RangeSet.  The data
249#      is returned as a list or tuple of strings; concatenating the
250#      elements together should produce the requested data.
251#      Implementations are free to break up the data into list/tuple
252#      elements in any way that is convenient.
253#
254#    TotalSha1(): a function that returns (as a hex string) the SHA-1
255#      hash of all the data in the image (ie, all the blocks in the
256#      care_map minus clobbered_blocks, or including the clobbered
257#      blocks if include_clobbered_blocks is True).
258#
259# When creating a BlockImageDiff, the src image may be None, in which
260# case the list of transfers produced will never read from the
261# original image.
262
263class BlockImageDiff(object):
264  def __init__(self, tgt, src=None, threads=None, version=4,
265               disable_imgdiff=False):
266    if threads is None:
267      threads = multiprocessing.cpu_count() // 2
268      if threads == 0:
269        threads = 1
270    self.threads = threads
271    self.version = version
272    self.transfers = []
273    self.src_basenames = {}
274    self.src_numpatterns = {}
275    self._max_stashed_size = 0
276    self.touched_src_ranges = RangeSet()
277    self.touched_src_sha1 = None
278    self.disable_imgdiff = disable_imgdiff
279
280    assert version in (1, 2, 3, 4)
281
282    self.tgt = tgt
283    if src is None:
284      src = EmptyImage()
285    self.src = src
286
287    # The updater code that installs the patch always uses 4k blocks.
288    assert tgt.blocksize == 4096
289    assert src.blocksize == 4096
290
291    # The range sets in each filemap should comprise a partition of
292    # the care map.
293    self.AssertPartition(src.care_map, src.file_map.values())
294    self.AssertPartition(tgt.care_map, tgt.file_map.values())
295
  @property
  def max_stashed_size(self):
    # Peak number of simultaneously stashed bytes; set by WriteTransfers()
    # (version >= 2), 0 until then.
    return self._max_stashed_size
299
  def Compute(self, prefix):
    """Run the full diff pipeline and write the output files at prefix.

    Builds the transfer list, orders it, enforces the stash budget, then
    emits prefix + ".new.dat"/".patch.dat" (ComputePatches) and
    prefix + ".transfer.list" (WriteTransfers).
    """
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)
334
335  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
336    data = source.ReadRangeSet(ranges)
337    ctx = sha1()
338
339    for p in data:
340      ctx.update(p)
341
342    return ctx.hexdigest()
343
  def WriteTransfers(self, prefix):
    """Write the transfer list for this diff to prefix + ".transfer.list".

    Emits one command per transfer (plus stash/free bookkeeping for
    version >= 2), tracks the peak stash usage against the /cache budget,
    and adds the erase commands for blocks the new image doesn't care
    about. Sets self._max_stashed_size and (version >= 3)
    self.touched_src_sha1 as side effects.
    """
    def WriteSplitTransfers(out, style, target_blocks):
      """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.

      This prevents the target size of one command from being too large; and
      might help to avoid fsync errors on some devices."""

      assert (style == "new" or style == "zero")
      blocks_limit = 1024
      total = 0
      while target_blocks:
        blocks_to_write = target_blocks.first(blocks_limit)
        out.append("%s %s\n" % (style, blocks_to_write.to_string_raw()))
        total += blocks_to_write.size()
        target_blocks = target_blocks.subtract(blocks_to_write)
      return total

    out = []

    total = 0

    # Maps stash raw id (v2) or block hash (v3+) to slot id / refcount.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    # Stash slot ids are recycled through this min-heap.
    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit the stash commands this transfer requires before it runs.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            self.touched_src_ranges = self.touched_src_ranges.union(sr)
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        # Build the <src_str> operand, one of:
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        assert tgt_size == WriteSplitTransfers(out, xf.style, xf.tgt_ranges)
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # An identity move (src == tgt) needs no command at all.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            self.touched_src_ranges = self.touched_src_ranges.union(
                xf.src_ranges)

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          self.touched_src_ranges = self.touched_src_ranges.union(
              xf.src_ranges)

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        assert WriteSplitTransfers(out, xf.style, to_zero) == to_zero.size()
        total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    if self.version >= 3:
      self.touched_src_sha1 = self.HashBlocks(
          self.src, self.touched_src_ranges)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      assert (WriteSplitTransfers(out, "zero", self.tgt.extended) ==
              self.tgt.extended.size())
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image; b) will not be touched by dm-verity. Out of those
    # blocks, we erase the ones that won't be used in this update at the
    # beginning of an update. The rest would be erased at the end. This is to
    # work around the eMMC issue observed on some devices, which may otherwise
    # get starved for clean blocks and thus fail the update. (b/28347095)
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)

    erase_first = new_dontcare.subtract(self.touched_src_ranges)
    if erase_first:
      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))

    erase_last = new_dontcare.subtract(erase_first)
    if erase_last:
      out.append("erase %s\n" % (erase_last.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, self._max_stashed_size, max_allowed,
              self._max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, self._max_stashed_size))
618
  def ReviseStashSize(self):
    """Keep the runtime stash within the /cache budget.

    Walks the transfer list simulating stash usage; any command whose
    (explicit or implicit) stash needs would exceed the allowed size is
    converted into a plain "new" command, and the corresponding stash
    definitions are removed.
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Add up blocks that violate the space limit and print the total
        # number to screen later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    num_of_bytes = new_blocks * self.tgt.blocksize
    print("  Total %d blocks (%d bytes) are packed as new blocks due to "
          "insufficient cache size." % (new_blocks, num_of_bytes))
691
  def ComputePatches(self, prefix):
    """Compute the patch or raw data for every transfer.

    Writes the concatenated "new" data to prefix + ".new.dat" and the
    bsdiff/imgdiff patches (computed on self.threads worker threads) to
    prefix + ".patch.dat". Resolves each "diff" transfer into "move",
    "bsdiff" or "imgdiff", and records patch_start/patch_len on it.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - imgdiff is not disabled, and
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (not self.disable_imgdiff and xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort by target size so the largest (slowest) diffs start first.
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Each worker pops one queue entry at a time under the lock, runs
        # the external diff tool without it, and stores the result.
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into the .patch.dat file, recording each
    # transfer's offset and length within it.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
797
798  def AssertSequenceGood(self):
799    # Simulate the sequences of transfers we will output, and check that:
800    # - we never read a block after writing it, and
801    # - we write every block we care about exactly once.
802
803    # Start with no blocks having been touched yet.
804    touched = array.array("B", "\0" * self.tgt.total_blocks)
805
806    # Imagine processing the transfers in order.
807    for xf in self.transfers:
808      # Check that the input blocks for this transfer haven't yet been touched.
809
810      x = xf.src_ranges
811      if self.version >= 2:
812        for _, sr in xf.use_stash:
813          x = x.subtract(sr)
814
815      for s, e in x:
816        # Source image could be larger. Don't check the blocks that are in the
817        # source image only. Since they are not in 'touched', and won't ever
818        # be touched.
819        for i in range(s, min(e, self.tgt.total_blocks)):
820          assert touched[i] == 0
821
822      # Check that the output blocks for this transfer haven't yet
823      # been touched, and touch all the blocks written by this
824      # transfer.
825      for s, e in xf.tgt_ranges:
826        for i in range(s, e):
827          assert touched[i] == 0
828          touched[i] = 1
829
830    # Check that we've written every target block.
831    for s, e in self.tgt.care_map:
832      for i in range(s, e):
833        assert touched[i] == 1
834
835  def ImproveVertexSequence(self):
836    print("Improving vertex order...")
837
838    # At this point our digraph is acyclic; we reversed any edges that
839    # were backwards in the heuristically-generated sequence.  The
840    # previously-generated order is still acceptable, but we hope to
841    # find a better order that needs less memory for stashed data.
842    # Now we do a topological sort to generate a new vertex order,
843    # using a greedy algorithm to choose which vertex goes next
844    # whenever we have a choice.
845
846    # Make a copy of the edge set; this copy will get destroyed by the
847    # algorithm.
848    for xf in self.transfers:
849      xf.incoming = xf.goes_after.copy()
850      xf.outgoing = xf.goes_before.copy()
851
852    L = []   # the new vertex order
853
854    # S is the set of sources in the remaining graph; we always choose
855    # the one that leaves the least amount of stashed data after it's
856    # executed.
857    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
858         if not u.incoming]
859    heapq.heapify(S)
860
861    while S:
862      _, _, xf = heapq.heappop(S)
863      L.append(xf)
864      for u in xf.outgoing:
865        del u.incoming[xf]
866        if not u.incoming:
867          heapq.heappush(S, (u.NetStashChange(), u.order, u))
868
869    # if this fails then our graph had a cycle.
870    assert len(L) == len(self.transfers)
871
872    self.transfers = L
873    for i, xf in enumerate(L):
874      xf.order = i
875
876  def RemoveBackwardEdges(self):
877    print("Removing backward edges...")
878    in_order = 0
879    out_of_order = 0
880    lost_source = 0
881
882    for xf in self.transfers:
883      lost = 0
884      size = xf.src_ranges.size()
885      for u in xf.goes_before:
886        # xf should go before u
887        if xf.order < u.order:
888          # it does, hurray!
889          in_order += 1
890        else:
891          # it doesn't, boo.  trim the blocks that u writes from xf's
892          # source, so that xf can go after u.
893          out_of_order += 1
894          assert xf.src_ranges.overlaps(u.tgt_ranges)
895          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
896          xf.intact = False
897
898      if xf.style == "diff" and not xf.src_ranges:
899        # nothing left to diff from; treat as new data
900        xf.style = "new"
901
902      lost = size - xf.src_ranges.size()
903      lost_source += lost
904
905    print(("  %d/%d dependencies (%.2f%%) were violated; "
906           "%d source blocks removed.") %
907          (out_of_order, in_order + out_of_order,
908           (out_of_order * 100.0 / (in_order + out_of_order))
909           if (in_order + out_of_order) else 0.0,
910           lost_source))
911
912  def ReverseBackwardEdges(self):
913    print("Reversing backward edges...")
914    in_order = 0
915    out_of_order = 0
916    stashes = 0
917    stash_size = 0
918
919    for xf in self.transfers:
920      for u in xf.goes_before.copy():
921        # xf should go before u
922        if xf.order < u.order:
923          # it does, hurray!
924          in_order += 1
925        else:
926          # it doesn't, boo.  modify u to stash the blocks that it
927          # writes that xf wants to read, and then require u to go
928          # before xf.
929          out_of_order += 1
930
931          overlap = xf.src_ranges.intersect(u.tgt_ranges)
932          assert overlap
933
934          u.stash_before.append((stashes, overlap))
935          xf.use_stash.append((stashes, overlap))
936          stashes += 1
937          stash_size += overlap.size()
938
939          # reverse the edge direction; now xf must go after u
940          del xf.goes_before[u]
941          del u.goes_after[xf]
942          xf.goes_after[u] = None    # value doesn't matter
943          u.goes_before[xf] = None
944
945    print(("  %d/%d dependencies (%.2f%%) were violated; "
946           "%d source blocks stashed.") %
947          (out_of_order, in_order + out_of_order,
948           (out_of_order * 100.0 / (in_order + out_of_order))
949           if (in_order + out_of_order) else 0.0,
950           stash_size))
951
  def FindVertexSequence(self):
    """Choose an order for self.transfers that minimizes the total weight
    of dependency edges pointing "backwards" in the sequence.

    On return, each transfer's 'order' field is set and self.transfers
    is rearranged into the chosen sequence; the temporary incoming/
    outgoing working copies are deleted.
    """
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()
      # score = (weight of outgoing edges) - (weight of incoming edges);
      # a high score means the vertex behaves like a source and should
      # be placed early.
      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    # Heap of HeapItem wrappers keyed on score; stale entries are handled
    # by lazy deletion (see adjust_score below).  HeapItem is defined
    # elsewhere in this file.
    heap = []
    for xf in self.transfers:
      xf.heap_item = HeapItem(xf)
      heap.append(xf.heap_item)
    heapq.heapify(heap)

    sinks = set(u for u in G if not u.outgoing)
    sources = set(u for u in G if not u.incoming)

    def adjust_score(iu, delta):
      # Update iu's score by invalidating its old heap entry and pushing
      # a fresh one.  clear() presumably marks the stale entry so the
      # pop loop below can skip it (the 'if u' truthiness check) --
      # TODO(review): confirm against HeapItem's definition.
      iu.score += delta
      iu.heap_item.clear()
      iu.heap_item = HeapItem(iu)
      heapq.heappush(heap, iu.heap_item)

    while G:
      # Put all sinks at the end of the sequence.
      while sinks:
        new_sinks = set()
        for u in sinks:
          if u not in G: continue
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            adjust_score(iu, -iu.outgoing.pop(u))
            if not iu.outgoing: new_sinks.add(iu)
        sinks = new_sinks

      # Put all the sources at the beginning of the sequence.
      while sources:
        new_sources = set()
        for u in sources:
          if u not in G: continue
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            adjust_score(iu, +iu.incoming.pop(u))
            if not iu.incoming: new_sources.add(iu)
        sources = new_sources

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      while True:
        u = heapq.heappop(heap)
        # Skip cleared (stale) heap items and items whose transfer has
        # already been removed from G.
        if u and u.item in G:
          u = u.item
          break

      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        adjust_score(iu, +iu.incoming.pop(u))
        if not iu.incoming: sources.add(iu)

      for iu in u.incoming:
        adjust_score(iu, -iu.outgoing.pop(u))
        if not iu.outgoing: sinks.add(iu)

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      # Drop the consumed working copies.
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers
1054
1055  def GenerateDigraph(self):
1056    print("Generating digraph...")
1057
1058    # Each item of source_ranges will be:
1059    #   - None, if that block is not used as a source,
1060    #   - a transfer, if one transfer uses it as a source, or
1061    #   - a set of transfers.
1062    source_ranges = []
1063    for b in self.transfers:
1064      for s, e in b.src_ranges:
1065        if e > len(source_ranges):
1066          source_ranges.extend([None] * (e-len(source_ranges)))
1067        for i in range(s, e):
1068          if source_ranges[i] is None:
1069            source_ranges[i] = b
1070          else:
1071            if not isinstance(source_ranges[i], set):
1072              source_ranges[i] = set([source_ranges[i]])
1073            source_ranges[i].add(b)
1074
1075    for a in self.transfers:
1076      intersections = set()
1077      for s, e in a.tgt_ranges:
1078        for i in range(s, e):
1079          if i >= len(source_ranges): break
1080          b = source_ranges[i]
1081          if b is not None:
1082            if isinstance(b, set):
1083              intersections.update(b)
1084            else:
1085              intersections.add(b)
1086
1087      for b in intersections:
1088        if a is b: continue
1089
1090        # If the blocks written by A are read by B, then B needs to go before A.
1091        i = a.tgt_ranges.intersect(b.src_ranges)
1092        if i:
1093          if b.src_name == "__ZERO":
1094            # the cost of removing source blocks for the __ZERO domain
1095            # is (nearly) zero.
1096            size = 0
1097          else:
1098            size = i.size()
1099          b.goes_before[a] = size
1100          a.goes_after[b] = size
1101
1102  def FindTransfers(self):
1103    """Parse the file_map to generate all the transfers."""
1104
1105    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1106                    split=False):
1107      """Wrapper function for adding a Transfer().
1108
1109      For BBOTA v3, we need to stash source blocks for resumable feature.
1110      However, with the growth of file size and the shrink of the cache
1111      partition source blocks are too large to be stashed. If a file occupies
1112      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
1113      into smaller pieces by getting multiple Transfer()s.
1114
1115      The downside is that after splitting, we may increase the package size
1116      since the split pieces don't align well. According to our experiments,
1117      1/8 of the cache size as the per-piece limit appears to be optimal.
1118      Compared to the fixed 1024-block limit, it reduces the overall package
1119      size by 30% volantis, and 20% for angler and bullhead."""
1120
1121      # We care about diff transfers only.
1122      if style != "diff" or not split:
1123        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1124        return
1125
1126      pieces = 0
1127      cache_size = common.OPTIONS.cache_size
1128      split_threshold = 0.125
1129      max_blocks_per_transfer = int(cache_size * split_threshold /
1130                                    self.tgt.blocksize)
1131
1132      # Change nothing for small files.
1133      if (tgt_ranges.size() <= max_blocks_per_transfer and
1134          src_ranges.size() <= max_blocks_per_transfer):
1135        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1136        return
1137
1138      while (tgt_ranges.size() > max_blocks_per_transfer and
1139             src_ranges.size() > max_blocks_per_transfer):
1140        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1141        src_split_name = "%s-%d" % (src_name, pieces)
1142        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1143        src_first = src_ranges.first(max_blocks_per_transfer)
1144
1145        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
1146                 by_id)
1147
1148        tgt_ranges = tgt_ranges.subtract(tgt_first)
1149        src_ranges = src_ranges.subtract(src_first)
1150        pieces += 1
1151
1152      # Handle remaining blocks.
1153      if tgt_ranges.size() or src_ranges.size():
1154        # Must be both non-empty.
1155        assert tgt_ranges.size() and src_ranges.size()
1156        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1157        src_split_name = "%s-%d" % (src_name, pieces)
1158        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
1159                 by_id)
1160
1161    empty = RangeSet()
1162    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
1163      if tgt_fn == "__ZERO":
1164        # the special "__ZERO" domain is all the blocks not contained
1165        # in any file and that are filled with zeros.  We have a
1166        # special transfer style for zero blocks.
1167        src_ranges = self.src.file_map.get("__ZERO", empty)
1168        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1169                    "zero", self.transfers)
1170        continue
1171
1172      elif tgt_fn == "__COPY":
1173        # "__COPY" domain includes all the blocks not contained in any
1174        # file and that need to be copied unconditionally to the target.
1175        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1176        continue
1177
1178      elif tgt_fn in self.src.file_map:
1179        # Look for an exact pathname match in the source.
1180        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1181                    "diff", self.transfers, self.version >= 3)
1182        continue
1183
1184      b = os.path.basename(tgt_fn)
1185      if b in self.src_basenames:
1186        # Look for an exact basename match in the source.
1187        src_fn = self.src_basenames[b]
1188        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1189                    "diff", self.transfers, self.version >= 3)
1190        continue
1191
1192      b = re.sub("[0-9]+", "#", b)
1193      if b in self.src_numpatterns:
1194        # Look for a 'number pattern' match (a basename match after
1195        # all runs of digits are replaced by "#").  (This is useful
1196        # for .so files that contain version numbers in the filename
1197        # that get bumped.)
1198        src_fn = self.src_numpatterns[b]
1199        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1200                    "diff", self.transfers, self.version >= 3)
1201        continue
1202
1203      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1204
1205  def AbbreviateSourceNames(self):
1206    for k in self.src.file_map.keys():
1207      b = os.path.basename(k)
1208      self.src_basenames[b] = k
1209      b = re.sub("[0-9]+", "#", b)
1210      self.src_numpatterns[b] = k
1211
1212  @staticmethod
1213  def AssertPartition(total, seq):
1214    """Assert that all the RangeSets in 'seq' form a partition of the
1215    'total' RangeSet (ie, they are nonintersecting and their union
1216    equals 'total')."""
1217
1218    so_far = RangeSet()
1219    for i in seq:
1220      assert not so_far.overlaps(i)
1221      so_far = so_far.union(i)
1222    assert so_far == total
1223