• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python2
2#
3# Copyright (C) 2013 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Unit testing checker.py."""
19
20from __future__ import print_function
21
22import array
23import collections
24import cStringIO
25import hashlib
26import itertools
27import os
28import unittest
29
30# pylint cannot find mox.
31# pylint: disable=F0401
32import mox
33
34from update_payload import checker
35from update_payload import common
36from update_payload import test_utils
37from update_payload import update_metadata_pb2
38from update_payload.error import PayloadError
39from update_payload.payload import Payload # Avoid name conflicts later.
40
41
def _OpTypeByName(op_name):
  """Maps an operation name such as 'REPLACE_BZ' to its common.OpType value.

  Args:
    op_name: The symbolic name of an install operation.

  Returns:
    The matching common.OpType constant.

  Raises:
    KeyError: If op_name is not one of the supported operation names.
  """
  supported_names = ('REPLACE', 'REPLACE_BZ', 'MOVE', 'BSDIFF', 'SOURCE_COPY',
                     'SOURCE_BSDIFF', 'ZERO', 'DISCARD', 'REPLACE_XZ',
                     'PUFFDIFF', 'BROTLI_BSDIFF')
  type_by_name = {name: getattr(common.OpType, name)
                  for name in supported_names}
  return type_by_name[op_name]
58
59
def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
                       checker_init_dargs=None):
  """Builds a PayloadChecker around a freshly generated in-memory payload.

  Args:
    payload_gen_write_to_file_func: Callable that serializes a payload into the
        file object given as its first argument.
    payload_gen_dargs: Optional keyword arguments forwarded to the writer.
    checker_init_dargs: Optional keyword arguments forwarded to the
        PayloadChecker constructor.

  Returns:
    A checker.PayloadChecker wrapping the parsed and initialized Payload.
  """
  writer_kwargs = {} if payload_gen_dargs is None else payload_gen_dargs
  checker_kwargs = {} if checker_init_dargs is None else checker_init_dargs

  # Serialize into memory, then rewind so Payload reads from the start.
  buf = cStringIO.StringIO()
  payload_gen_write_to_file_func(buf, **writer_kwargs)
  buf.seek(0)

  parsed_payload = Payload(buf)
  parsed_payload.Init()
  return checker.PayloadChecker(parsed_payload, **checker_kwargs)
74
75
def _GetPayloadCheckerWithData(payload_gen):
  """Returns a payload checker for the payload produced by a generator object.

  Convenience wrapper over _GetPayloadChecker() for callers that hold a
  generator instance rather than its WriteToFile method. Neither the writer
  nor the checker receives extra keyword arguments.

  Args:
    payload_gen: An object exposing WriteToFile(file_obj).

  Returns:
    A checker.PayloadChecker constructed with default arguments.
  """
  return _GetPayloadChecker(payload_gen.WriteToFile)
84
85
86# This class doesn't need an __init__().
87# pylint: disable=W0232
88# Unit testing is all about running protected methods.
89# pylint: disable=W0212
90# Don't bark about missing members of classes you cannot import.
91# pylint: disable=E1101
92class PayloadCheckerTest(mox.MoxTestBase):
93  """Tests the PayloadChecker class.
94
95  In addition to ordinary testFoo() methods, which are automatically invoked by
96  the unittest framework, in this class we make use of DoBarTest() calls that
97  implement parametric tests of certain features. In order to invoke each test,
98  which embodies a unique combination of parameter values, as a complete unit
99  test, we perform explicit enumeration of the parameter space and create
100  individual invocation contexts for each, which are then bound as
101  testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
102  all such tests is done in AddAllParametricTests().
103  """
104
105  def MockPayload(self):
106    """Create a mock payload object, complete with a mock manifest."""
107    payload = self.mox.CreateMock(Payload)
108    payload.is_init = True
109    payload.manifest = self.mox.CreateMock(
110        update_metadata_pb2.DeltaArchiveManifest)
111    return payload
112
113  @staticmethod
114  def NewExtent(start_block, num_blocks):
115    """Returns an Extent message.
116
117    Each of the provided fields is set iff it is >= 0; otherwise, it's left at
118    its default state.
119
120    Args:
121      start_block: The starting block of the extent.
122      num_blocks: The number of blocks in the extent.
123
124    Returns:
125      An Extent message.
126    """
127    ex = update_metadata_pb2.Extent()
128    if start_block >= 0:
129      ex.start_block = start_block
130    if num_blocks >= 0:
131      ex.num_blocks = num_blocks
132    return ex
133
134  @staticmethod
135  def NewExtentList(*args):
136    """Returns an list of extents.
137
138    Args:
139      *args: (start_block, num_blocks) pairs defining the extents.
140
141    Returns:
142      A list of Extent objects.
143    """
144    ex_list = []
145    for start_block, num_blocks in args:
146      ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks))
147    return ex_list
148
149  @staticmethod
150  def AddToMessage(repeated_field, field_vals):
151    for field_val in field_vals:
152      new_field = repeated_field.add()
153      new_field.CopyFrom(field_val)
154
155  def SetupAddElemTest(self, is_present, is_submsg, convert=str,
156                       linebreak=False, indent=0):
157    """Setup for testing of _CheckElem() and its derivatives.
158
159    Args:
160      is_present: Whether or not the element is found in the message.
161      is_submsg: Whether the element is a sub-message itself.
162      convert: A representation conversion function.
163      linebreak: Whether or not a linebreak is to be used in the report.
164      indent: Indentation used for the report.
165
166    Returns:
167      msg: A mock message object.
168      report: A mock report object.
169      subreport: A mock sub-report object.
170      name: An element name to check.
171      val: Expected element value.
172    """
173    name = 'foo'
174    val = 'fake submsg' if is_submsg else 'fake field'
175    subreport = 'fake subreport'
176
177    # Create a mock message.
178    msg = self.mox.CreateMock(update_metadata_pb2._message.Message)
179    msg.HasField(name).AndReturn(is_present)
180    setattr(msg, name, val)
181
182    # Create a mock report.
183    report = self.mox.CreateMock(checker._PayloadReport)
184    if is_present:
185      if is_submsg:
186        report.AddSubReport(name).AndReturn(subreport)
187      else:
188        report.AddField(name, convert(val), linebreak=linebreak, indent=indent)
189
190    self.mox.ReplayAll()
191    return (msg, report, subreport, name, val)
192
193  def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
194                    linebreak, indent):
195    """Parametric testing of _CheckElem().
196
197    Args:
198      is_present: Whether or not the element is found in the message.
199      is_mandatory: Whether or not it's a mandatory element.
200      is_submsg: Whether the element is a sub-message itself.
201      convert: A representation conversion function.
202      linebreak: Whether or not a linebreak is to be used in the report.
203      indent: Indentation used for the report.
204    """
205    msg, report, subreport, name, val = self.SetupAddElemTest(
206        is_present, is_submsg, convert, linebreak, indent)
207
208    args = (msg, name, report, is_mandatory, is_submsg)
209    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
210    if is_mandatory and not is_present:
211      self.assertRaises(PayloadError,
212                        checker.PayloadChecker._CheckElem, *args, **kwargs)
213    else:
214      ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
215                                                                 **kwargs)
216      self.assertEquals(val if is_present else None, ret_val)
217      self.assertEquals(subreport if is_present and is_submsg else None,
218                        ret_subreport)
219
220  def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
221                     indent):
222    """Parametric testing of _Check{Mandatory,Optional}Field().
223
224    Args:
225      is_mandatory: Whether we're testing a mandatory call.
226      is_present: Whether or not the element is found in the message.
227      convert: A representation conversion function.
228      linebreak: Whether or not a linebreak is to be used in the report.
229      indent: Indentation used for the report.
230    """
231    msg, report, _, name, val = self.SetupAddElemTest(
232        is_present, False, convert, linebreak, indent)
233
234    # Prepare for invocation of the tested method.
235    args = [msg, name, report]
236    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
237    if is_mandatory:
238      args.append('bar')
239      tested_func = checker.PayloadChecker._CheckMandatoryField
240    else:
241      tested_func = checker.PayloadChecker._CheckOptionalField
242
243    # Test the method call.
244    if is_mandatory and not is_present:
245      self.assertRaises(PayloadError, tested_func, *args, **kwargs)
246    else:
247      ret_val = tested_func(*args, **kwargs)
248      self.assertEquals(val if is_present else None, ret_val)
249
250  def DoAddSubMsgTest(self, is_mandatory, is_present):
251    """Parametrized testing of _Check{Mandatory,Optional}SubMsg().
252
253    Args:
254      is_mandatory: Whether we're testing a mandatory call.
255      is_present: Whether or not the element is found in the message.
256    """
257    msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)
258
259    # Prepare for invocation of the tested method.
260    args = [msg, name, report]
261    if is_mandatory:
262      args.append('bar')
263      tested_func = checker.PayloadChecker._CheckMandatorySubMsg
264    else:
265      tested_func = checker.PayloadChecker._CheckOptionalSubMsg
266
267    # Test the method call.
268    if is_mandatory and not is_present:
269      self.assertRaises(PayloadError, tested_func, *args)
270    else:
271      ret_val, ret_subreport = tested_func(*args)
272      self.assertEquals(val if is_present else None, ret_val)
273      self.assertEquals(subreport if is_present else None, ret_subreport)
274
275  def testCheckPresentIff(self):
276    """Tests _CheckPresentIff()."""
277    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
278        None, None, 'foo', 'bar', 'baz'))
279    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
280        'a', 'b', 'foo', 'bar', 'baz'))
281    self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
282                      'a', None, 'foo', 'bar', 'baz')
283    self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
284                      None, 'b', 'foo', 'bar', 'baz')
285
286  def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
287                                 sig_data, sig_asn1_header,
288                                 returned_signed_hash, expected_signed_hash):
289    """Parametric testing of _CheckSha256SignatureTest().
290
291    Args:
292      expect_pass: Whether or not it should pass.
293      expect_subprocess_call: Whether to expect the openssl call to happen.
294      sig_data: The signature raw data.
295      sig_asn1_header: The ASN1 header.
296      returned_signed_hash: The signed hash data retuned by openssl.
297      expected_signed_hash: The signed hash data to compare against.
298    """
299    try:
300      # Stub out the subprocess invocation.
301      self.mox.StubOutWithMock(checker.PayloadChecker, '_Run')
302      if expect_subprocess_call:
303        checker.PayloadChecker._Run(
304            mox.IsA(list), send_data=sig_data).AndReturn(
305                (sig_asn1_header + returned_signed_hash, None))
306
307      self.mox.ReplayAll()
308      if expect_pass:
309        self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
310            sig_data, 'foo', expected_signed_hash, 'bar'))
311      else:
312        self.assertRaises(PayloadError,
313                          checker.PayloadChecker._CheckSha256Signature,
314                          sig_data, 'foo', expected_signed_hash, 'bar')
315    finally:
316      self.mox.UnsetStubs()
317
318  def testCheckSha256Signature_Pass(self):
319    """Tests _CheckSha256Signature(); pass case."""
320    sig_data = 'fake-signature'.ljust(256)
321    signed_hash = hashlib.sha256('fake-data').digest()
322    self.DoCheckSha256SignatureTest(True, True, sig_data,
323                                    common.SIG_ASN1_HEADER, signed_hash,
324                                    signed_hash)
325
326  def testCheckSha256Signature_FailBadSignature(self):
327    """Tests _CheckSha256Signature(); fails due to malformed signature."""
328    sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
329    signed_hash = hashlib.sha256('fake-data').digest()
330    self.DoCheckSha256SignatureTest(False, False, sig_data,
331                                    common.SIG_ASN1_HEADER, signed_hash,
332                                    signed_hash)
333
334  def testCheckSha256Signature_FailBadOutputLength(self):
335    """Tests _CheckSha256Signature(); fails due to unexpected output length."""
336    sig_data = 'fake-signature'.ljust(256)
337    signed_hash = 'fake-hash'  # Malformed (not 32 bytes in length).
338    self.DoCheckSha256SignatureTest(False, True, sig_data,
339                                    common.SIG_ASN1_HEADER, signed_hash,
340                                    signed_hash)
341
342  def testCheckSha256Signature_FailBadAsnHeader(self):
343    """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
344    sig_data = 'fake-signature'.ljust(256)
345    signed_hash = hashlib.sha256('fake-data').digest()
346    bad_asn1_header = 'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
347    self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
348                                    signed_hash, signed_hash)
349
350  def testCheckSha256Signature_FailBadHash(self):
351    """Tests _CheckSha256Signature(); fails due to bad hash returned."""
352    sig_data = 'fake-signature'.ljust(256)
353    expected_signed_hash = hashlib.sha256('fake-data').digest()
354    returned_signed_hash = hashlib.sha256('bad-fake-data').digest()
355    self.DoCheckSha256SignatureTest(False, True, sig_data,
356                                    common.SIG_ASN1_HEADER,
357                                    expected_signed_hash, returned_signed_hash)
358
359  def testCheckBlocksFitLength_Pass(self):
360    """Tests _CheckBlocksFitLength(); pass case."""
361    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
362        64, 4, 16, 'foo'))
363    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
364        60, 4, 16, 'foo'))
365    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
366        49, 4, 16, 'foo'))
367    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
368        48, 3, 16, 'foo'))
369
370  def testCheckBlocksFitLength_TooManyBlocks(self):
371    """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
372    self.assertRaises(PayloadError,
373                      checker.PayloadChecker._CheckBlocksFitLength,
374                      64, 5, 16, 'foo')
375    self.assertRaises(PayloadError,
376                      checker.PayloadChecker._CheckBlocksFitLength,
377                      60, 5, 16, 'foo')
378    self.assertRaises(PayloadError,
379                      checker.PayloadChecker._CheckBlocksFitLength,
380                      49, 5, 16, 'foo')
381    self.assertRaises(PayloadError,
382                      checker.PayloadChecker._CheckBlocksFitLength,
383                      48, 4, 16, 'foo')
384
385  def testCheckBlocksFitLength_TooFewBlocks(self):
386    """Tests _CheckBlocksFitLength(); fails due to insufficient blocks."""
387    self.assertRaises(PayloadError,
388                      checker.PayloadChecker._CheckBlocksFitLength,
389                      64, 3, 16, 'foo')
390    self.assertRaises(PayloadError,
391                      checker.PayloadChecker._CheckBlocksFitLength,
392                      60, 3, 16, 'foo')
393    self.assertRaises(PayloadError,
394                      checker.PayloadChecker._CheckBlocksFitLength,
395                      49, 3, 16, 'foo')
396    self.assertRaises(PayloadError,
397                      checker.PayloadChecker._CheckBlocksFitLength,
398                      48, 2, 16, 'foo')
399
400  def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs,
401                          fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori,
402                          fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size,
403                          fail_old_rootfs_fs_size, fail_new_kernel_fs_size,
404                          fail_new_rootfs_fs_size):
405    """Parametric testing of _CheckManifest().
406
407    Args:
408      fail_mismatched_block_size: Simulate a missing block_size field.
409      fail_bad_sigs: Make signatures descriptor inconsistent.
410      fail_mismatched_oki_ori: Make old rootfs/kernel info partially present.
411      fail_bad_oki: Tamper with old kernel info.
412      fail_bad_ori: Tamper with old rootfs info.
413      fail_bad_nki: Tamper with new kernel info.
414      fail_bad_nri: Tamper with new rootfs info.
415      fail_old_kernel_fs_size: Make old kernel fs size too big.
416      fail_old_rootfs_fs_size: Make old rootfs fs size too big.
417      fail_new_kernel_fs_size: Make new kernel fs size too big.
418      fail_new_rootfs_fs_size: Make new rootfs fs size too big.
419    """
420    # Generate a test payload. For this test, we only care about the manifest
421    # and don't need any data blobs, hence we can use a plain paylaod generator
422    # (which also gives us more control on things that can be screwed up).
423    payload_gen = test_utils.PayloadGenerator()
424
425    # Tamper with block size, if required.
426    if fail_mismatched_block_size:
427      payload_gen.SetBlockSize(test_utils.KiB(1))
428    else:
429      payload_gen.SetBlockSize(test_utils.KiB(4))
430
431    # Add some operations.
432    payload_gen.AddOperation(False, common.OpType.MOVE,
433                             src_extents=[(0, 16), (16, 497)],
434                             dst_extents=[(16, 496), (0, 16)])
435    payload_gen.AddOperation(True, common.OpType.MOVE,
436                             src_extents=[(0, 8), (8, 8)],
437                             dst_extents=[(8, 8), (0, 8)])
438
439    # Set an invalid signatures block (offset but no size), if required.
440    if fail_bad_sigs:
441      payload_gen.SetSignatures(32, None)
442
443    # Set partition / filesystem sizes.
444    rootfs_part_size = test_utils.MiB(8)
445    kernel_part_size = test_utils.KiB(512)
446    old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size
447    old_kernel_fs_size = new_kernel_fs_size = kernel_part_size
448    if fail_old_kernel_fs_size:
449      old_kernel_fs_size += 100
450    if fail_old_rootfs_fs_size:
451      old_rootfs_fs_size += 100
452    if fail_new_kernel_fs_size:
453      new_kernel_fs_size += 100
454    if fail_new_rootfs_fs_size:
455      new_rootfs_fs_size += 100
456
457    # Add old kernel/rootfs partition info, as required.
458    if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki:
459      oki_hash = (None if fail_bad_oki
460                  else hashlib.sha256('fake-oki-content').digest())
461      payload_gen.SetPartInfo(True, False, old_kernel_fs_size, oki_hash)
462    if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or
463                                        fail_bad_ori):
464      ori_hash = (None if fail_bad_ori
465                  else hashlib.sha256('fake-ori-content').digest())
466      payload_gen.SetPartInfo(False, False, old_rootfs_fs_size, ori_hash)
467
468    # Add new kernel/rootfs partition info.
469    payload_gen.SetPartInfo(
470        True, True, new_kernel_fs_size,
471        None if fail_bad_nki else hashlib.sha256('fake-nki-content').digest())
472    payload_gen.SetPartInfo(
473        False, True, new_rootfs_fs_size,
474        None if fail_bad_nri else hashlib.sha256('fake-nri-content').digest())
475
476    # Set the minor version.
477    payload_gen.SetMinorVersion(0)
478
479    # Create the test object.
480    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile)
481    report = checker._PayloadReport()
482
483    should_fail = (fail_mismatched_block_size or fail_bad_sigs or
484                   fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or
485                   fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or
486                   fail_old_rootfs_fs_size or fail_new_kernel_fs_size or
487                   fail_new_rootfs_fs_size)
488    part_sizes = {
489        common.ROOTFS: rootfs_part_size,
490        common.KERNEL: kernel_part_size
491    }
492
493    if should_fail:
494      self.assertRaises(PayloadError, payload_checker._CheckManifest, report,
495                        part_sizes)
496    else:
497      self.assertIsNone(payload_checker._CheckManifest(report, part_sizes))
498
499  def testCheckLength(self):
500    """Tests _CheckLength()."""
501    payload_checker = checker.PayloadChecker(self.MockPayload())
502    block_size = payload_checker.block_size
503
504    # Passes.
505    self.assertIsNone(payload_checker._CheckLength(
506        int(3.5 * block_size), 4, 'foo', 'bar'))
507    # Fails, too few blocks.
508    self.assertRaises(PayloadError, payload_checker._CheckLength,
509                      int(3.5 * block_size), 3, 'foo', 'bar')
510    # Fails, too many blocks.
511    self.assertRaises(PayloadError, payload_checker._CheckLength,
512                      int(3.5 * block_size), 5, 'foo', 'bar')
513
514  def testCheckExtents(self):
515    """Tests _CheckExtents()."""
516    payload_checker = checker.PayloadChecker(self.MockPayload())
517    block_size = payload_checker.block_size
518
519    # Passes w/ all real extents.
520    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
521    self.assertEquals(
522        23,
523        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
524                                      collections.defaultdict(int), 'foo'))
525
526    # Passes w/ pseudo-extents (aka sparse holes).
527    extents = self.NewExtentList((0, 4), (common.PSEUDO_EXTENT_MARKER, 5),
528                                 (8, 3))
529    self.assertEquals(
530        12,
531        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
532                                      collections.defaultdict(int), 'foo',
533                                      allow_pseudo=True))
534
535    # Passes w/ pseudo-extent due to a signature.
536    extents = self.NewExtentList((common.PSEUDO_EXTENT_MARKER, 2))
537    self.assertEquals(
538        2,
539        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
540                                      collections.defaultdict(int), 'foo',
541                                      allow_signature=True))
542
543    # Fails, extent missing a start block.
544    extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16))
545    self.assertRaises(
546        PayloadError, payload_checker._CheckExtents, extents,
547        (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
548
549    # Fails, extent missing block count.
550    extents = self.NewExtentList((0, -1), (8, 3), (1024, 16))
551    self.assertRaises(
552        PayloadError, payload_checker._CheckExtents, extents,
553        (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
554
555    # Fails, extent has zero blocks.
556    extents = self.NewExtentList((0, 4), (8, 3), (1024, 0))
557    self.assertRaises(
558        PayloadError, payload_checker._CheckExtents, extents,
559        (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
560
561    # Fails, extent exceeds partition boundaries.
562    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
563    self.assertRaises(
564        PayloadError, payload_checker._CheckExtents, extents,
565        (1024 + 15) * block_size, collections.defaultdict(int), 'foo')
566
567  def testCheckReplaceOperation(self):
568    """Tests _CheckReplaceOperation() where op.type == REPLACE."""
569    payload_checker = checker.PayloadChecker(self.MockPayload())
570    block_size = payload_checker.block_size
571    data_length = 10000
572
573    op = self.mox.CreateMock(
574        update_metadata_pb2.InstallOperation)
575    op.type = common.OpType.REPLACE
576
577    # Pass.
578    op.src_extents = []
579    self.assertIsNone(
580        payload_checker._CheckReplaceOperation(
581            op, data_length, (data_length + block_size - 1) / block_size,
582            'foo'))
583
584    # Fail, src extents founds.
585    op.src_extents = ['bar']
586    self.assertRaises(
587        PayloadError, payload_checker._CheckReplaceOperation,
588        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
589
590    # Fail, missing data.
591    op.src_extents = []
592    self.assertRaises(
593        PayloadError, payload_checker._CheckReplaceOperation,
594        op, None, (data_length + block_size - 1) / block_size, 'foo')
595
596    # Fail, length / block number mismatch.
597    op.src_extents = ['bar']
598    self.assertRaises(
599        PayloadError, payload_checker._CheckReplaceOperation,
600        op, data_length, (data_length + block_size - 1) / block_size + 1, 'foo')
601
602  def testCheckReplaceBzOperation(self):
603    """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ."""
604    payload_checker = checker.PayloadChecker(self.MockPayload())
605    block_size = payload_checker.block_size
606    data_length = block_size * 3
607
608    op = self.mox.CreateMock(
609        update_metadata_pb2.InstallOperation)
610    op.type = common.OpType.REPLACE_BZ
611
612    # Pass.
613    op.src_extents = []
614    self.assertIsNone(
615        payload_checker._CheckReplaceOperation(
616            op, data_length, (data_length + block_size - 1) / block_size + 5,
617            'foo'))
618
619    # Fail, src extents founds.
620    op.src_extents = ['bar']
621    self.assertRaises(
622        PayloadError, payload_checker._CheckReplaceOperation,
623        op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo')
624
625    # Fail, missing data.
626    op.src_extents = []
627    self.assertRaises(
628        PayloadError, payload_checker._CheckReplaceOperation,
629        op, None, (data_length + block_size - 1) / block_size, 'foo')
630
631    # Fail, too few blocks to justify BZ.
632    op.src_extents = []
633    self.assertRaises(
634        PayloadError, payload_checker._CheckReplaceOperation,
635        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
636
637  def testCheckReplaceXzOperation(self):
638    """Tests _CheckReplaceOperation() where op.type == REPLACE_XZ."""
639    payload_checker = checker.PayloadChecker(self.MockPayload())
640    block_size = payload_checker.block_size
641    data_length = block_size * 3
642
643    op = self.mox.CreateMock(
644        update_metadata_pb2.InstallOperation)
645    op.type = common.OpType.REPLACE_XZ
646
647    # Pass.
648    op.src_extents = []
649    self.assertIsNone(
650        payload_checker._CheckReplaceOperation(
651            op, data_length, (data_length + block_size - 1) / block_size + 5,
652            'foo'))
653
654    # Fail, src extents founds.
655    op.src_extents = ['bar']
656    self.assertRaises(
657        PayloadError, payload_checker._CheckReplaceOperation,
658        op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo')
659
660    # Fail, missing data.
661    op.src_extents = []
662    self.assertRaises(
663        PayloadError, payload_checker._CheckReplaceOperation,
664        op, None, (data_length + block_size - 1) / block_size, 'foo')
665
666    # Fail, too few blocks to justify XZ.
667    op.src_extents = []
668    self.assertRaises(
669        PayloadError, payload_checker._CheckReplaceOperation,
670        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
671
672  def testCheckMoveOperation_Pass(self):
673    """Tests _CheckMoveOperation(); pass case."""
674    payload_checker = checker.PayloadChecker(self.MockPayload())
675    op = update_metadata_pb2.InstallOperation()
676    op.type = common.OpType.MOVE
677
678    self.AddToMessage(op.src_extents,
679                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
680    self.AddToMessage(op.dst_extents,
681                      self.NewExtentList((16, 128), (512, 6)))
682    self.assertIsNone(
683        payload_checker._CheckMoveOperation(op, None, 134, 134, 'foo'))
684
685  def testCheckMoveOperation_FailContainsData(self):
686    """Tests _CheckMoveOperation(); fails, message contains data."""
687    payload_checker = checker.PayloadChecker(self.MockPayload())
688    op = update_metadata_pb2.InstallOperation()
689    op.type = common.OpType.MOVE
690
691    self.AddToMessage(op.src_extents,
692                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
693    self.AddToMessage(op.dst_extents,
694                      self.NewExtentList((16, 128), (512, 6)))
695    self.assertRaises(
696        PayloadError, payload_checker._CheckMoveOperation,
697        op, 1024, 134, 134, 'foo')
698
699  def testCheckMoveOperation_FailInsufficientSrcBlocks(self):
700    """Tests _CheckMoveOperation(); fails, not enough actual src blocks."""
701    payload_checker = checker.PayloadChecker(self.MockPayload())
702    op = update_metadata_pb2.InstallOperation()
703    op.type = common.OpType.MOVE
704
705    self.AddToMessage(op.src_extents,
706                      self.NewExtentList((1, 4), (12, 2), (1024, 127)))
707    self.AddToMessage(op.dst_extents,
708                      self.NewExtentList((16, 128), (512, 6)))
709    self.assertRaises(
710        PayloadError, payload_checker._CheckMoveOperation,
711        op, None, 134, 134, 'foo')
712
713  def testCheckMoveOperation_FailInsufficientDstBlocks(self):
714    """Tests _CheckMoveOperation(); fails, not enough actual dst blocks."""
715    payload_checker = checker.PayloadChecker(self.MockPayload())
716    op = update_metadata_pb2.InstallOperation()
717    op.type = common.OpType.MOVE
718
719    self.AddToMessage(op.src_extents,
720                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
721    self.AddToMessage(op.dst_extents,
722                      self.NewExtentList((16, 128), (512, 5)))
723    self.assertRaises(
724        PayloadError, payload_checker._CheckMoveOperation,
725        op, None, 134, 134, 'foo')
726
727  def testCheckMoveOperation_FailExcessSrcBlocks(self):
728    """Tests _CheckMoveOperation(); fails, too many actual src blocks."""
729    payload_checker = checker.PayloadChecker(self.MockPayload())
730    op = update_metadata_pb2.InstallOperation()
731    op.type = common.OpType.MOVE
732
733    self.AddToMessage(op.src_extents,
734                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
735    self.AddToMessage(op.dst_extents,
736                      self.NewExtentList((16, 128), (512, 5)))
737    self.assertRaises(
738        PayloadError, payload_checker._CheckMoveOperation,
739        op, None, 134, 134, 'foo')
740    self.AddToMessage(op.src_extents,
741                      self.NewExtentList((1, 4), (12, 2), (1024, 129)))
742    self.AddToMessage(op.dst_extents,
743                      self.NewExtentList((16, 128), (512, 6)))
744    self.assertRaises(
745        PayloadError, payload_checker._CheckMoveOperation,
746        op, None, 134, 134, 'foo')
747
748  def testCheckMoveOperation_FailExcessDstBlocks(self):
749    """Tests _CheckMoveOperation(); fails, too many actual dst blocks."""
750    payload_checker = checker.PayloadChecker(self.MockPayload())
751    op = update_metadata_pb2.InstallOperation()
752    op.type = common.OpType.MOVE
753
754    self.AddToMessage(op.src_extents,
755                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
756    self.AddToMessage(op.dst_extents,
757                      self.NewExtentList((16, 128), (512, 7)))
758    self.assertRaises(
759        PayloadError, payload_checker._CheckMoveOperation,
760        op, None, 134, 134, 'foo')
761
762  def testCheckMoveOperation_FailStagnantBlocks(self):
763    """Tests _CheckMoveOperation(); fails, there are blocks that do not move."""
764    payload_checker = checker.PayloadChecker(self.MockPayload())
765    op = update_metadata_pb2.InstallOperation()
766    op.type = common.OpType.MOVE
767
768    self.AddToMessage(op.src_extents,
769                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
770    self.AddToMessage(op.dst_extents,
771                      self.NewExtentList((8, 128), (512, 6)))
772    self.assertRaises(
773        PayloadError, payload_checker._CheckMoveOperation,
774        op, None, 134, 134, 'foo')
775
776  def testCheckMoveOperation_FailZeroStartBlock(self):
777    """Tests _CheckMoveOperation(); fails, has extent with start block 0."""
778    payload_checker = checker.PayloadChecker(self.MockPayload())
779    op = update_metadata_pb2.InstallOperation()
780    op.type = common.OpType.MOVE
781
782    self.AddToMessage(op.src_extents,
783                      self.NewExtentList((0, 4), (12, 2), (1024, 128)))
784    self.AddToMessage(op.dst_extents,
785                      self.NewExtentList((8, 128), (512, 6)))
786    self.assertRaises(
787        PayloadError, payload_checker._CheckMoveOperation,
788        op, None, 134, 134, 'foo')
789
790    self.AddToMessage(op.src_extents,
791                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
792    self.AddToMessage(op.dst_extents,
793                      self.NewExtentList((0, 128), (512, 6)))
794    self.assertRaises(
795        PayloadError, payload_checker._CheckMoveOperation,
796        op, None, 134, 134, 'foo')
797
798  def testCheckAnyDiff(self):
799    """Tests _CheckAnyDiffOperation()."""
800    payload_checker = checker.PayloadChecker(self.MockPayload())
801    op = update_metadata_pb2.InstallOperation()
802
803    # Pass.
804    self.assertIsNone(
805        payload_checker._CheckAnyDiffOperation(op, 10000, 3, 'foo'))
806
807    # Fail, missing data blob.
808    self.assertRaises(
809        PayloadError, payload_checker._CheckAnyDiffOperation,
810        op, None, 3, 'foo')
811
812    # Fail, too big of a diff blob (unjustified).
813    self.assertRaises(
814        PayloadError, payload_checker._CheckAnyDiffOperation,
815        op, 10000, 2, 'foo')
816
817  def testCheckSourceCopyOperation_Pass(self):
818    """Tests _CheckSourceCopyOperation(); pass case."""
819    payload_checker = checker.PayloadChecker(self.MockPayload())
820    self.assertIsNone(
821        payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))
822
823  def testCheckSourceCopyOperation_FailContainsData(self):
824    """Tests _CheckSourceCopyOperation(); message contains data."""
825    payload_checker = checker.PayloadChecker(self.MockPayload())
826    self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
827                      134, 0, 0, 'foo')
828
829  def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
830    """Tests _CheckSourceCopyOperation(); src and dst block totals not equal."""
831    payload_checker = checker.PayloadChecker(self.MockPayload())
832    self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
833                      None, 0, 1, 'foo')
834
835  def DoCheckOperationTest(self, op_type_name, is_last, allow_signature,
836                           allow_unhashed, fail_src_extents, fail_dst_extents,
837                           fail_mismatched_data_offset_length,
838                           fail_missing_dst_extents, fail_src_length,
839                           fail_dst_length, fail_data_hash,
840                           fail_prev_data_offset, fail_bad_minor_version):
841    """Parametric testing of _CheckOperation().
842
843    Args:
844      op_type_name: 'REPLACE', 'REPLACE_BZ', 'REPLACE_XZ', 'MOVE', 'BSDIFF',
845        'SOURCE_COPY', 'SOURCE_BSDIFF', BROTLI_BSDIFF or 'PUFFDIFF'.
846      is_last: Whether we're testing the last operation in a sequence.
847      allow_signature: Whether we're testing a signature-capable operation.
848      allow_unhashed: Whether we're allowing to not hash the data.
849      fail_src_extents: Tamper with src extents.
850      fail_dst_extents: Tamper with dst extents.
851      fail_mismatched_data_offset_length: Make data_{offset,length}
852        inconsistent.
853      fail_missing_dst_extents: Do not include dst extents.
854      fail_src_length: Make src length inconsistent.
855      fail_dst_length: Make dst length inconsistent.
856      fail_data_hash: Tamper with the data blob hash.
857      fail_prev_data_offset: Make data space uses incontiguous.
858      fail_bad_minor_version: Make minor version incompatible with op.
859    """
860    op_type = _OpTypeByName(op_type_name)
861
862    # Create the test object.
863    payload = self.MockPayload()
864    payload_checker = checker.PayloadChecker(payload,
865                                             allow_unhashed=allow_unhashed)
866    block_size = payload_checker.block_size
867
868    # Create auxiliary arguments.
869    old_part_size = test_utils.MiB(4)
870    new_part_size = test_utils.MiB(8)
871    old_block_counters = array.array(
872        'B', [0] * ((old_part_size + block_size - 1) / block_size))
873    new_block_counters = array.array(
874        'B', [0] * ((new_part_size + block_size - 1) / block_size))
875    prev_data_offset = 1876
876    blob_hash_counts = collections.defaultdict(int)
877
878    # Create the operation object for the test.
879    op = update_metadata_pb2.InstallOperation()
880    op.type = op_type
881
882    total_src_blocks = 0
883    if op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
884                   common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF,
885                   common.OpType.PUFFDIFF, common.OpType.BROTLI_BSDIFF):
886      if fail_src_extents:
887        self.AddToMessage(op.src_extents,
888                          self.NewExtentList((1, 0)))
889      else:
890        self.AddToMessage(op.src_extents,
891                          self.NewExtentList((1, 16)))
892        total_src_blocks = 16
893
894    # TODO(tbrindus): add major version 2 tests.
895    payload_checker.major_version = common.CHROMEOS_MAJOR_PAYLOAD_VERSION
896    if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
897      payload_checker.minor_version = 0
898    elif op_type in (common.OpType.MOVE, common.OpType.BSDIFF):
899      payload_checker.minor_version = 2 if fail_bad_minor_version else 1
900    elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
901      payload_checker.minor_version = 1 if fail_bad_minor_version else 2
902    if op_type == common.OpType.REPLACE_XZ:
903      payload_checker.minor_version = 2 if fail_bad_minor_version else 3
904    elif op_type in (common.OpType.ZERO, common.OpType.DISCARD,
905                     common.OpType.BROTLI_BSDIFF):
906      payload_checker.minor_version = 3 if fail_bad_minor_version else 4
907    elif op_type == common.OpType.PUFFDIFF:
908      payload_checker.minor_version = 4 if fail_bad_minor_version else 5
909
910    if op_type not in (common.OpType.MOVE, common.OpType.SOURCE_COPY):
911      if not fail_mismatched_data_offset_length:
912        op.data_length = 16 * block_size - 8
913      if fail_prev_data_offset:
914        op.data_offset = prev_data_offset + 16
915      else:
916        op.data_offset = prev_data_offset
917
918      fake_data = 'fake-data'.ljust(op.data_length)
919      if not (allow_unhashed or (is_last and allow_signature and
920                                 op_type == common.OpType.REPLACE)):
921        if not fail_data_hash:
922          # Create a valid data blob hash.
923          op.data_sha256_hash = hashlib.sha256(fake_data).digest()
924          payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
925              fake_data)
926
927      elif fail_data_hash:
928        # Create an invalid data blob hash.
929        op.data_sha256_hash = hashlib.sha256(
930            fake_data.replace(' ', '-')).digest()
931        payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
932            fake_data)
933
934    total_dst_blocks = 0
935    if not fail_missing_dst_extents:
936      total_dst_blocks = 16
937      if fail_dst_extents:
938        self.AddToMessage(op.dst_extents,
939                          self.NewExtentList((4, 16), (32, 0)))
940      else:
941        self.AddToMessage(op.dst_extents,
942                          self.NewExtentList((4, 8), (64, 8)))
943
944    if total_src_blocks:
945      if fail_src_length:
946        op.src_length = total_src_blocks * block_size + 8
947      elif (op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
948                        common.OpType.SOURCE_BSDIFF) and
949            payload_checker.minor_version <= 3):
950        op.src_length = total_src_blocks * block_size
951    elif fail_src_length:
952      # Add an orphaned src_length.
953      op.src_length = 16
954
955    if total_dst_blocks:
956      if fail_dst_length:
957        op.dst_length = total_dst_blocks * block_size + 8
958      elif (op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
959                        common.OpType.SOURCE_BSDIFF) and
960            payload_checker.minor_version <= 3):
961        op.dst_length = total_dst_blocks * block_size
962
963    self.mox.ReplayAll()
964    should_fail = (fail_src_extents or fail_dst_extents or
965                   fail_mismatched_data_offset_length or
966                   fail_missing_dst_extents or fail_src_length or
967                   fail_dst_length or fail_data_hash or fail_prev_data_offset or
968                   fail_bad_minor_version)
969    args = (op, 'foo', is_last, old_block_counters, new_block_counters,
970            old_part_size, new_part_size, prev_data_offset, allow_signature,
971            blob_hash_counts)
972    if should_fail:
973      self.assertRaises(PayloadError, payload_checker._CheckOperation, *args)
974    else:
975      self.assertEqual(op.data_length if op.HasField('data_length') else 0,
976                       payload_checker._CheckOperation(*args))
977
978  def testAllocBlockCounters(self):
979    """Tests _CheckMoveOperation()."""
980    payload_checker = checker.PayloadChecker(self.MockPayload())
981    block_size = payload_checker.block_size
982
983    # Check allocation for block-aligned partition size, ensure it's integers.
984    result = payload_checker._AllocBlockCounters(16 * block_size)
985    self.assertEqual(16, len(result))
986    self.assertEqual(int, type(result[0]))
987
988    # Check allocation of unaligned partition sizes.
989    result = payload_checker._AllocBlockCounters(16 * block_size - 1)
990    self.assertEqual(16, len(result))
991    result = payload_checker._AllocBlockCounters(16 * block_size + 1)
992    self.assertEqual(17, len(result))
993
994  def DoCheckOperationsTest(self, fail_nonexhaustive_full_update):
995    """Tests _CheckOperations()."""
996    # Generate a test payload. For this test, we only care about one
997    # (arbitrary) set of operations, so we'll only be generating kernel and
998    # test with them.
999    payload_gen = test_utils.PayloadGenerator()
1000
1001    block_size = test_utils.KiB(4)
1002    payload_gen.SetBlockSize(block_size)
1003
1004    rootfs_part_size = test_utils.MiB(8)
1005
1006    # Fake rootfs operations in a full update, tampered with as required.
1007    rootfs_op_type = common.OpType.REPLACE
1008    rootfs_data_length = rootfs_part_size
1009    if fail_nonexhaustive_full_update:
1010      rootfs_data_length -= block_size
1011
1012    payload_gen.AddOperation(False, rootfs_op_type,
1013                             dst_extents=[(0, rootfs_data_length / block_size)],
1014                             data_offset=0,
1015                             data_length=rootfs_data_length)
1016
1017    # Create the test object.
1018    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile,
1019                                         checker_init_dargs={
1020                                             'allow_unhashed': True})
1021    payload_checker.payload_type = checker._TYPE_FULL
1022    report = checker._PayloadReport()
1023
1024    args = (payload_checker.payload.manifest.install_operations, report, 'foo',
1025            0, rootfs_part_size, rootfs_part_size, rootfs_part_size, 0, False)
1026    if fail_nonexhaustive_full_update:
1027      self.assertRaises(PayloadError, payload_checker._CheckOperations, *args)
1028    else:
1029      self.assertEqual(rootfs_data_length,
1030                       payload_checker._CheckOperations(*args))
1031
  def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_missing_pseudo_op,
                            fail_mismatched_pseudo_op, fail_sig_missing_fields,
                            fail_unknown_sig_version, fail_incorrect_sig):
    """Tests _CheckSignatures().

    Builds a full payload with one rootfs REPLACE operation. Depending on the
    fail_* flags, it substitutes a forged signatures blob and/or a forged
    signature pseudo-operation. _CheckSignatures() is expected to raise
    PayloadError if any fail_* flag is set, and to return None otherwise.

    Args:
      fail_empty_sigs_blob: Forge a signatures blob with no signature in it.
      fail_missing_pseudo_op: Stop the generator from adding the signature
        pseudo-operation and add a forged one instead. This is handled the
        same way as fail_mismatched_pseudo_op.
      fail_mismatched_pseudo_op: Add a forged pseudo-operation whose
        data_offset/data_length are half the current offset and half the
        signatures blob length.
      fail_sig_missing_fields: Forge a signature entry whose data is None.
      fail_unknown_sig_version: Forge a signature with version 5 instead of 1.
      fail_incorrect_sig: Forge a signature computed over the string
        'fake-payload-content' rather than over the generated payload.
    """
    # Generate a test payload. For this test, we only care about the signature
    # block and how it relates to the payload hash. Therefore, we're generating
    # a random (otherwise useless) payload for this purpose.
    payload_gen = test_utils.EnhancedPayloadGenerator()
    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)
    rootfs_part_size = test_utils.MiB(2)
    kernel_part_size = test_utils.KiB(16)
    payload_gen.SetPartInfo(False, True, rootfs_part_size,
                            hashlib.sha256('fake-new-rootfs-content').digest())
    payload_gen.SetPartInfo(True, True, kernel_part_size,
                            hashlib.sha256('fake-new-kernel-content').digest())
    payload_gen.SetMinorVersion(0)
    payload_gen.AddOperationWithData(
        False, common.OpType.REPLACE,
        dst_extents=[(0, rootfs_part_size / block_size)],
        data_blob=os.urandom(rootfs_part_size))

    do_forge_pseudo_op = (fail_missing_pseudo_op or fail_mismatched_pseudo_op)
    do_forge_sigs_data = (do_forge_pseudo_op or fail_empty_sigs_blob or
                          fail_sig_missing_fields or fail_unknown_sig_version
                          or fail_incorrect_sig)

    # Build our own signatures blob only when a failure is being forced. In
    # the pass case sigs_data stays None and is handed to the generator as-is.
    sigs_data = None
    if do_forge_sigs_data:
      sigs_gen = test_utils.SignaturesGenerator()
      if not fail_empty_sigs_blob:
        if fail_sig_missing_fields:
          sig_data = None
        else:
          # Signs a fixed string rather than the generated payload.
          sig_data = test_utils.SignSha256('fake-payload-content',
                                           test_utils._PRIVKEY_FILE_NAME)
        sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data)

      sigs_data = sigs_gen.ToBinary()
      payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data))

    if do_forge_pseudo_op:
      assert sigs_data is not None, 'should have forged signatures blob by now'
      sigs_len = len(sigs_data)
      # Deliberately points at half the offset/length passed to
      # SetSignatures() above.
      payload_gen.AddOperation(
          False, common.OpType.REPLACE,
          data_offset=payload_gen.curr_offset / 2,
          data_length=sigs_len / 2,
          dst_extents=[(0, (sigs_len / 2 + block_size - 1) / block_size)])

    # Generate payload (complete w/ signature) and create the test object.
    payload_checker = _GetPayloadChecker(
        payload_gen.WriteToFileWithData,
        payload_gen_dargs={
            'sigs_data': sigs_data,
            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
            'do_add_pseudo_operation': not do_forge_pseudo_op})
    payload_checker.payload_type = checker._TYPE_FULL
    report = checker._PayloadReport()

    # We have to check the manifest first in order to set signature attributes.
    payload_checker._CheckManifest(report, {
        common.ROOTFS: rootfs_part_size,
        common.KERNEL: kernel_part_size
    })

    should_fail = (fail_empty_sigs_blob or fail_missing_pseudo_op or
                   fail_mismatched_pseudo_op or fail_sig_missing_fields or
                   fail_unknown_sig_version or fail_incorrect_sig)
    args = (report, test_utils._PUBKEY_FILE_NAME)
    if should_fail:
      self.assertRaises(PayloadError, payload_checker._CheckSignatures, *args)
    else:
      self.assertIsNone(payload_checker._CheckSignatures(*args))
1106
1107  def DoCheckManifestMinorVersionTest(self, minor_version, payload_type):
1108    """Parametric testing for CheckManifestMinorVersion().
1109
1110    Args:
1111      minor_version: The payload minor version to test with.
1112      payload_type: The type of the payload we're testing, delta or full.
1113    """
1114    # Create the test object.
1115    payload = self.MockPayload()
1116    payload.manifest.minor_version = minor_version
1117    payload_checker = checker.PayloadChecker(payload)
1118    payload_checker.payload_type = payload_type
1119    report = checker._PayloadReport()
1120
1121    should_succeed = (
1122        (minor_version == 0 and payload_type == checker._TYPE_FULL) or
1123        (minor_version == 1 and payload_type == checker._TYPE_DELTA) or
1124        (minor_version == 2 and payload_type == checker._TYPE_DELTA) or
1125        (minor_version == 3 and payload_type == checker._TYPE_DELTA) or
1126        (minor_version == 4 and payload_type == checker._TYPE_DELTA) or
1127        (minor_version == 5 and payload_type == checker._TYPE_DELTA))
1128    args = (report,)
1129
1130    if should_succeed:
1131      self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args))
1132    else:
1133      self.assertRaises(PayloadError,
1134                        payload_checker._CheckManifestMinorVersion, *args)
1135
  def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided,
                fail_wrong_payload_type, fail_invalid_block_size,
                fail_mismatched_metadata_size, fail_mismatched_block_size,
                fail_excess_data, fail_rootfs_part_size_exceeded,
                fail_kernel_part_size_exceeded):
    """Tests Run().

    Args:
      rootfs_part_size_provided: Pass Run() an explicit rootfs partition size,
        one block larger than the rootfs filesystem size.
      kernel_part_size_provided: Likewise for the kernel partition.
      fail_wrong_payload_type: Construct the checker with assert_type 'delta'
        although a full payload is generated.
      fail_invalid_block_size: Construct the checker with a block size that is
        not a power of two; construction itself is expected to fail.
      fail_mismatched_metadata_size: Pass Run() a metadata size one byte off.
      fail_mismatched_block_size: Construct the checker with twice the block
        size the payload states.
      fail_excess_data: Generate the payload with 1 KiB of random padding.
      fail_rootfs_part_size_exceeded: Make the rootfs operation one block
        larger than the rootfs partition (or filesystem) size.
      fail_kernel_part_size_exceeded: Likewise for the kernel operation.
    """
    # Generate a test payload. For this test, we generate a full update that
    # has sample kernel and rootfs operations. Since most testing is done with
    # internal PayloadChecker methods that are tested elsewhere, here we only
    # tamper with what's actually being manipulated and/or tested in the Run()
    # method itself. Note that the checker doesn't verify partition hashes, so
    # they're safe to fake.
    payload_gen = test_utils.EnhancedPayloadGenerator()
    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)
    kernel_filesystem_size = test_utils.KiB(16)
    rootfs_filesystem_size = test_utils.MiB(2)
    payload_gen.SetPartInfo(False, True, rootfs_filesystem_size,
                            hashlib.sha256('fake-new-rootfs-content').digest())
    payload_gen.SetPartInfo(True, True, kernel_filesystem_size,
                            hashlib.sha256('fake-new-kernel-content').digest())
    payload_gen.SetMinorVersion(0)

    # 0 means "not provided"; the op then spans the filesystem size instead.
    rootfs_part_size = 0
    if rootfs_part_size_provided:
      rootfs_part_size = rootfs_filesystem_size + block_size
    rootfs_op_size = rootfs_part_size or rootfs_filesystem_size
    if fail_rootfs_part_size_exceeded:
      rootfs_op_size += block_size
    payload_gen.AddOperationWithData(
        False, common.OpType.REPLACE,
        dst_extents=[(0, rootfs_op_size / block_size)],
        data_blob=os.urandom(rootfs_op_size))

    kernel_part_size = 0
    if kernel_part_size_provided:
      kernel_part_size = kernel_filesystem_size + block_size
    kernel_op_size = kernel_part_size or kernel_filesystem_size
    if fail_kernel_part_size_exceeded:
      kernel_op_size += block_size
    payload_gen.AddOperationWithData(
        True, common.OpType.REPLACE,
        dst_extents=[(0, kernel_op_size / block_size)],
        data_blob=os.urandom(kernel_op_size))

    # Generate payload (complete w/ signature) and create the test object.
    if fail_invalid_block_size:
      use_block_size = block_size + 5  # Not a power of two.
    elif fail_mismatched_block_size:
      use_block_size = block_size * 2  # Differs from what the payload states.
    else:
      use_block_size = block_size

    # 246 is the metadata size of the payload generated above. It presumably
    # must be updated whenever the generated payload's metadata changes.
    metadata_size = 246
    if fail_mismatched_metadata_size:
      metadata_size += 1

    kwargs = {
        'payload_gen_dargs': {
            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
            'do_add_pseudo_operation': True,
            'is_pseudo_in_kernel': True,
            'padding': os.urandom(1024) if fail_excess_data else None},
        'checker_init_dargs': {
            'assert_type': 'delta' if fail_wrong_payload_type else 'full',
            'block_size': use_block_size}}
    if fail_invalid_block_size:
      self.assertRaises(PayloadError, _GetPayloadChecker,
                        payload_gen.WriteToFileWithData, **kwargs)
    else:
      payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
                                           **kwargs)

      kwargs = {
          'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
          'metadata_size': metadata_size,
          'part_sizes': {
              common.KERNEL: kernel_part_size,
              common.ROOTFS: rootfs_part_size}}

      should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
                     fail_mismatched_metadata_size or fail_excess_data or
                     fail_rootfs_part_size_exceeded or
                     fail_kernel_part_size_exceeded)
      if should_fail:
        self.assertRaises(PayloadError, payload_checker.Run, **kwargs)
      else:
        self.assertIsNone(payload_checker.Run(**kwargs))
1219                     fail_rootfs_part_size_exceeded or
1220                     fail_kernel_part_size_exceeded)
1221      if should_fail:
1222        self.assertRaises(PayloadError, payload_checker.Run, **kwargs)
1223      else:
1224        self.assertIsNone(payload_checker.Run(**kwargs))
1225
1226# This implements a generic API, hence the occasional unused args.
1227# pylint: disable=W0613
1228def ValidateCheckOperationTest(op_type_name, is_last, allow_signature,
1229                               allow_unhashed, fail_src_extents,
1230                               fail_dst_extents,
1231                               fail_mismatched_data_offset_length,
1232                               fail_missing_dst_extents, fail_src_length,
1233                               fail_dst_length, fail_data_hash,
1234                               fail_prev_data_offset, fail_bad_minor_version):
1235  """Returns True iff the combination of arguments represents a valid test."""
1236  op_type = _OpTypeByName(op_type_name)
1237
1238  # REPLACE/REPLACE_BZ/REPLACE_XZ operations don't read data from src
1239  # partition. They are compatible with all valid minor versions, so we don't
1240  # need to check that.
1241  if (op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ,
1242                  common.OpType.REPLACE_XZ) and (fail_src_extents or
1243                                                 fail_src_length or
1244                                                 fail_bad_minor_version)):
1245    return False
1246
1247  # MOVE and SOURCE_COPY operations don't carry data.
1248  if (op_type in (common.OpType.MOVE, common.OpType.SOURCE_COPY) and (
1249      fail_mismatched_data_offset_length or fail_data_hash or
1250      fail_prev_data_offset)):
1251    return False
1252
1253  return True
1254
1255
1256def TestMethodBody(run_method_name, run_dargs):
1257  """Returns a function that invokes a named method with named arguments."""
1258  return lambda self: getattr(self, run_method_name)(**run_dargs)
1259
1260
1261def AddParametricTests(tested_method_name, arg_space, validate_func=None):
1262  """Enumerates and adds specific parametric tests to PayloadCheckerTest.
1263
1264  This function enumerates a space of test parameters (defined by arg_space),
1265  then binds a new, unique method name in PayloadCheckerTest to a test function
1266  that gets handed the said parameters. This is a preferable approach to doing
1267  the enumeration and invocation during the tests because this way each test is
1268  treated as a complete run by the unittest framework, and so benefits from the
1269  usual setUp/tearDown mechanics.
1270
1271  Args:
1272    tested_method_name: Name of the tested PayloadChecker method.
1273    arg_space: A dictionary containing variables (keys) and lists of values
1274               (values) associated with them.
1275    validate_func: A function used for validating test argument combinations.
1276  """
1277  for value_tuple in itertools.product(*arg_space.itervalues()):
1278    run_dargs = dict(zip(arg_space.iterkeys(), value_tuple))
1279    if validate_func and not validate_func(**run_dargs):
1280      continue
1281    run_method_name = 'Do%sTest' % tested_method_name
1282    test_method_name = 'test%s' % tested_method_name
1283    for arg_key, arg_val in run_dargs.iteritems():
1284      if arg_val or type(arg_val) is int:
1285        test_method_name += '__%s=%s' % (arg_key, arg_val)
1286    setattr(PayloadCheckerTest, test_method_name,
1287            TestMethodBody(run_method_name, run_dargs))
1288
1289
1290def AddAllParametricTests():
1291  """Enumerates and adds all parametric tests to PayloadCheckerTest."""
1292  # Add all _CheckElem() test cases.
1293  AddParametricTests('AddElem',
1294                     {'linebreak': (True, False),
1295                      'indent': (0, 1, 2),
1296                      'convert': (str, lambda s: s[::-1]),
1297                      'is_present': (True, False),
1298                      'is_mandatory': (True, False),
1299                      'is_submsg': (True, False)})
1300
1301  # Add all _Add{Mandatory,Optional}Field tests.
1302  AddParametricTests('AddField',
1303                     {'is_mandatory': (True, False),
1304                      'linebreak': (True, False),
1305                      'indent': (0, 1, 2),
1306                      'convert': (str, lambda s: s[::-1]),
1307                      'is_present': (True, False)})
1308
1309  # Add all _Add{Mandatory,Optional}SubMsg tests.
1310  AddParametricTests('AddSubMsg',
1311                     {'is_mandatory': (True, False),
1312                      'is_present': (True, False)})
1313
1314  # Add all _CheckManifest() test cases.
1315  AddParametricTests('CheckManifest',
1316                     {'fail_mismatched_block_size': (True, False),
1317                      'fail_bad_sigs': (True, False),
1318                      'fail_mismatched_oki_ori': (True, False),
1319                      'fail_bad_oki': (True, False),
1320                      'fail_bad_ori': (True, False),
1321                      'fail_bad_nki': (True, False),
1322                      'fail_bad_nri': (True, False),
1323                      'fail_old_kernel_fs_size': (True, False),
1324                      'fail_old_rootfs_fs_size': (True, False),
1325                      'fail_new_kernel_fs_size': (True, False),
1326                      'fail_new_rootfs_fs_size': (True, False)})
1327
1328  # Add all _CheckOperation() test cases.
1329  AddParametricTests('CheckOperation',
1330                     {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'REPLACE_XZ',
1331                                       'MOVE', 'BSDIFF', 'SOURCE_COPY',
1332                                       'SOURCE_BSDIFF', 'PUFFDIFF',
1333                                       'BROTLI_BSDIFF'),
1334                      'is_last': (True, False),
1335                      'allow_signature': (True, False),
1336                      'allow_unhashed': (True, False),
1337                      'fail_src_extents': (True, False),
1338                      'fail_dst_extents': (True, False),
1339                      'fail_mismatched_data_offset_length': (True, False),
1340                      'fail_missing_dst_extents': (True, False),
1341                      'fail_src_length': (True, False),
1342                      'fail_dst_length': (True, False),
1343                      'fail_data_hash': (True, False),
1344                      'fail_prev_data_offset': (True, False),
1345                      'fail_bad_minor_version': (True, False)},
1346                     validate_func=ValidateCheckOperationTest)
1347
1348  # Add all _CheckOperations() test cases.
1349  AddParametricTests('CheckOperations',
1350                     {'fail_nonexhaustive_full_update': (True, False)})
1351
1352  # Add all _CheckOperations() test cases.
1353  AddParametricTests('CheckSignatures',
1354                     {'fail_empty_sigs_blob': (True, False),
1355                      'fail_missing_pseudo_op': (True, False),
1356                      'fail_mismatched_pseudo_op': (True, False),
1357                      'fail_sig_missing_fields': (True, False),
1358                      'fail_unknown_sig_version': (True, False),
1359                      'fail_incorrect_sig': (True, False)})
1360
1361  # Add all _CheckManifestMinorVersion() test cases.
1362  AddParametricTests('CheckManifestMinorVersion',
1363                     {'minor_version': (None, 0, 1, 2, 3, 4, 5, 555),
1364                      'payload_type': (checker._TYPE_FULL,
1365                                       checker._TYPE_DELTA)})
1366
1367  # Add all Run() test cases.
1368  AddParametricTests('Run',
1369                     {'rootfs_part_size_provided': (True, False),
1370                      'kernel_part_size_provided': (True, False),
1371                      'fail_wrong_payload_type': (True, False),
1372                      'fail_invalid_block_size': (True, False),
1373                      'fail_mismatched_metadata_size': (True, False),
1374                      'fail_mismatched_block_size': (True, False),
1375                      'fail_excess_data': (True, False),
1376                      'fail_rootfs_part_size_exceeded': (True, False),
1377                      'fail_kernel_part_size_exceeded': (True, False)})
1378
1379
1380if __name__ == '__main__':
1381  AddAllParametricTests()
1382  unittest.main()
1383