• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import itertools
20import multiprocessing
21import os
22import pprint
23import re
24import subprocess
25import sys
26import threading
27import tempfile
28
29from rangelib import *
30
# Public API of this module; Transfer and compute_patch are internal helpers.
__all__ = [
    "EmptyImage",
    "DataImage",
    "BlockImageDiff",
]
32
def compute_patch(src, tgt, imgdiff=False):
  """Run an external diff tool to compute a patch from src to tgt.

  The pieces of src and of tgt are each concatenated, in order, into a
  temporary file, and the diff tool writes its patch to a third temporary
  file.  All three temporary files are removed before returning, whether
  or not the diff succeeded.

  Args:
    src: iterable of byte strings holding the source data.
    tgt: iterable of byte strings holding the target data.
    imgdiff: if True run "imgdiff -z" (with its output discarded);
      otherwise run "bsdiff".

  Returns:
    The contents of the patch file produced by the tool.

  Raises:
    ValueError: if the diff tool exits with a nonzero status.
    OSError: if a temporary file cannot be created or the tool cannot be
      executed.
  """
  # Every temporary path created so far, so cleanup covers exactly those
  # even when a later mkstemp or write fails.
  tmpfiles = []
  try:
    srcfd, srcfile = tempfile.mkstemp(prefix="src-")
    tmpfiles.append(srcfile)
    with os.fdopen(srcfd, "wb") as f_src:
      for piece in src:
        f_src.write(piece)

    tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
    tmpfiles.append(tgtfile)
    with os.fdopen(tgtfd, "wb") as f_tgt:
      for piece in tgt:
        f_tgt.write(piece)

    patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
    tmpfiles.append(patchfile)
    os.close(patchfd)
    # mkstemp only reserves a unique name; remove the empty placeholder
    # before the diff tool writes the real patch there.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Close the null sink afterwards instead of leaking one handle per call.
      with open(os.devnull, "a") as devnull:
        status = subprocess.call(
            ["imgdiff", "-z", srcfile, tgtfile, patchfile],
            stdout=devnull, stderr=subprocess.STDOUT)
    else:
      status = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if status:
      raise ValueError("diff failed: " + str(status))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Unlink each file independently: a missing patch file (e.g. after a
    # failed diff) must not prevent removal of the others.
    for path in tmpfiles:
      try:
        os.unlink(path)
      except OSError:
        pass
70
class EmptyImage(object):
  """An image holding zero blocks of data.

  BlockImageDiff substitutes one of these when it is given no source
  image, so no transfer can read any source blocks.
  """
  blocksize = 4096
  total_blocks = 0
  care_map = RangeSet()
  file_map = {}

  def ReadRangeSet(self, ranges):
    """Return the (always empty) data for ranges."""
    return tuple()

  def TotalSha1(self):
    """Return the hex SHA-1 of no data at all."""
    return sha1(b"").hexdigest()
81
82
class DataImage(object):
  """An image wrapped around a single string of data.

  The data is divided into 4096-byte blocks, all of which are in the care
  map.  file_map splits them into two domains: "__ZERO" (blocks made up
  entirely of zero bytes) and "__NONZERO" (every other block).
  """

  def __init__(self, data, trim=False, pad=False):
    """Wrap data as an image.

    Args:
      data: the image contents as a byte string.
      trim: if True, drop a trailing partial block.
      pad: if True, zero-fill a trailing partial block up to a full block.
        trim and pad are mutually exclusive.

    Raises:
      ValueError: if len(data) is not a multiple of the block size and
        neither trim nor pad was requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        # b"\0" is a plain str on Python 2 and bytes on Python 3, so the
        # padding always has the same type as the image data.
        self.data += b"\0" * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps the block count an int (it feeds range() and
    # RangeSet) even when true division is in effect.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    # Flat lists of start/end pairs; block i contributes the pair (i, i+1).
    zero_blocks = []
    nonzero_blocks = []
    reference = b"\0" * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    """Return a list with the data of each (start, end) block range."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    """Return the hex SHA-1 of all the data, computing it on first use."""
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
131
132
class Transfer(object):
  """One step of the update: produce tgt_ranges, possibly reading src_ranges.

  style names how the target blocks are produced (e.g. "zero", "new",
  "diff", and later "move", "bsdiff" or "imgdiff").  Creating a Transfer
  appends it to by_id, and its index in that list becomes its id.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name, self.src_name = tgt_name, src_name
    self.tgt_ranges, self.src_ranges = tgt_ranges, src_ranges
    self.style = style

    # True only when both range sets report being monotonic; imgdiff is
    # allowed only for intact transfers.
    tgt_is_monotonic = getattr(tgt_ranges, "monotonic", False)
    self.intact = tgt_is_monotonic and getattr(src_ranges, "monotonic", False)

    # Ordering edges: other transfer -> number of source blocks involved.
    self.goes_before = {}
    self.goes_after = {}

    by_id.append(self)
    self.id = len(by_id) - 1

  def __str__(self):
    return "%s: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
151
152
153# BlockImageDiff works on two image objects.  An image object is
154# anything that provides the following attributes:
155#
156#    blocksize: the size in bytes of a block, currently must be 4096.
157#
158#    total_blocks: the total size of the partition/image, in blocks.
159#
160#    care_map: a RangeSet containing which blocks (in the range [0,
161#      total_blocks) we actually care about; i.e. which blocks contain
162#      data.
163#
164#    file_map: a dict that partitions the blocks contained in care_map
165#      into smaller domains that are useful for doing diffs on.
166#      (Typically a domain is a file, and the key in file_map is the
167#      pathname.)
168#
169#    ReadRangeSet(): a function that takes a RangeSet and returns the
170#      data contained in the image blocks of that RangeSet.  The data
171#      is returned as a list or tuple of strings; concatenating the
172#      elements together should produce the requested data.
173#      Implementations are free to break up the data into list/tuple
174#      elements in any way that is convenient.
175#
176#    TotalSha1(): a function that returns (as a hex string) the SHA-1
177#      hash of all the data in the image (ie, all the blocks in the
178#      care_map)
179#
180# When creating a BlockImageDiff, the src image may be None, in which
181# case the list of transfers produced will never read from the
182# original image.
183
class BlockImageDiff(object):
  """Computes a block-based update from a source image to a target image.

  The image objects must provide the attributes described in the comment
  block immediately above this class.  Compute(prefix) writes
  prefix + ".transfer.list", prefix + ".new.dat" and prefix + ".patch.dat".
  """

  def __init__(self, tgt, src=None, threads=None):
    """Prepare to diff src against tgt.

    Args:
      tgt: the target image.
      src: the source image, or None to use an EmptyImage (so the resulting
        transfers never read from the original image).
      threads: number of worker threads used to compute patches; defaults
        to half the CPU count.  Values below 1 are raised to 1.
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
    # Zero (or negative) workers would leave every patch uncomputed, so
    # always run at least one.
    self.threads = max(1, threads)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Compute the update and write its output files named with prefix."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) a exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write prefix + ".transfer.list" describing self.transfers.

    The first line is the format version ("1") and the second the number
    of blocks written by the commands that follow, one per transfer
    (transfers with nothing to do are omitted).  An erase command is added
    right after the header when no transfer reads the source, otherwise
    possibly at the end for blocks outside the target care map.

    Raises:
      ValueError: if a transfer has an unknown style.
    """
    out = []

    out.append("1\n")   # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move onto the same blocks is a no-op and is not emitted.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that were already zero in the source need no rewrite.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # Call form works on both Python 2 and 3 (the old
        # "raise E, msg" form is Python 2 only).
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Write the data files for self.transfers.

    Target data of "new" transfers goes to prefix + ".new.dat" in transfer
    order.  Each "diff" transfer becomes a "move" when its source and
    target data are identical, and otherwise "imgdiff" or "bsdiff"; those
    patches are computed on self.threads worker threads, written to
    prefix + ".patch.dat" in transfer order, and located through the
    patch_start / patch_len attributes set on each transfer.

    Raises:
      ValueError: if computing any patch fails.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort on size (ties broken by patch number) instead of on the whole
      # tuple, which would fall through to comparing the src/tgt data
      # (possibly tens of megabytes) and then Transfer objects.  Workers
      # pop() from the end, so the largest diffs are started first.
      diff_q.sort(key=lambda item: (item[0], item[4]))

      patches = [None] * patch_num
      failures = []

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            # Stop handing out work as soon as any patch has failed.
            if not diff_q or failures:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          try:
            patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          except Exception as e:  # re-raised by the main thread below
            with lock:
              failures.append((xf, e))
            return
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size,
                (size * 100.0 / tgt_size) if tgt_size else 0.0,
                xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      workers = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in workers:
        th.start()
      while workers:
        workers.pop().join()

      # join() does not propagate a worker's exception; report it here
      # rather than failing later on a missing (None) patch entry.
      if failures:
        failed_xf, err = failures[0]
        raise ValueError("failed to compute %s patch for %s: %s" %
                         (failed_xf.style, failed_xf.tgt_name, err))
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Assert that self.transfers, in order, is a valid update sequence."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Drop ordering dependencies that the chosen sequence violates.

    When xf should precede u but was ordered after it, the blocks u writes
    are removed from xf's source ranges (and xf is no longer intact).  A
    "diff" transfer left with no source blocks becomes "new".
    """
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Order self.transfers to violate as few weighted dependencies as possible.

    Sets each transfer's 'order' attribute to its position and replaces
    self.transfers with the new sequence.
    """
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Add an ordering edge wherever one transfer reads blocks another writes.

    If b reads blocks that a writes, b must go before a; the edge weight is
    the number of such blocks (0 when b reads from the "__ZERO" domain).
    """
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create self.transfers: one Transfer per target file_map domain.

    The source for each target is chosen as described in Compute();
    targets with no matching source become "new" transfers.
    """
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Build basename and digit-pattern lookup tables for source names.

    self.src_basenames maps a basename to a full source path, and
    self.src_numpatterns maps the basename with every run of digits
    replaced by "#" to a full source path.  When several source files
    share a key, the last one iterated wins.
    """
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map:
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
623