• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python2
2#
3# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7"""Unit testing checker.py."""
8
9from __future__ import print_function
10
11import array
12import collections
13import cStringIO
14import hashlib
15import itertools
16import os
17import unittest
18
19# pylint cannot find mox.
20# pylint: disable=F0401
21import mox
22
23import checker
24import common
25import payload as update_payload  # Avoid name conflicts later.
26import test_utils
27import update_metadata_pb2
28
29
def _OpTypeByName(op_name):
  """Maps an install operation name to its common.OpType value.

  Args:
    op_name: Operation type name, e.g. 'REPLACE' or 'SOURCE_COPY'.

  Returns:
    The matching common.OpType constant.

  Raises:
    KeyError: If op_name is not one of the supported operation names.
  """
  return {
      'REPLACE': common.OpType.REPLACE,
      'REPLACE_BZ': common.OpType.REPLACE_BZ,
      'MOVE': common.OpType.MOVE,
      'BSDIFF': common.OpType.BSDIFF,
      'SOURCE_COPY': common.OpType.SOURCE_COPY,
      'SOURCE_BSDIFF': common.OpType.SOURCE_BSDIFF,
      'ZERO': common.OpType.ZERO,
      'DISCARD': common.OpType.DISCARD,
      'REPLACE_XZ': common.OpType.REPLACE_XZ,
      'IMGDIFF': common.OpType.IMGDIFF,
  }[op_name]
44
45
def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
                       checker_init_dargs=None):
  """Builds a PayloadChecker over a payload written to an in-memory file.

  The writer function fills a fresh in-memory buffer, which is then rewound,
  parsed as a Payload and initialized before being handed to the checker.

  Args:
    payload_gen_write_to_file_func: Callable that writes a payload into the
        file object it receives as its first argument.
    payload_gen_dargs: Optional dict of extra keyword arguments for the writer.
    checker_init_dargs: Optional dict of keyword arguments forwarded to the
        PayloadChecker constructor.

  Returns:
    A checker.PayloadChecker wrapping the initialized payload.
  """
  write_kwargs = payload_gen_dargs if payload_gen_dargs is not None else {}
  checker_kwargs = checker_init_dargs if checker_init_dargs is not None else {}

  buf = cStringIO.StringIO()
  payload_gen_write_to_file_func(buf, **write_kwargs)
  buf.seek(0)

  parsed = update_payload.Payload(buf)
  parsed.Init()
  return checker.PayloadChecker(parsed, **checker_kwargs)
60
61
def _GetPayloadCheckerWithData(payload_gen):
  """Returns a payload checker for the payload produced by a generator object.

  Args:
    payload_gen: An object exposing WriteToFile(file_obj) that writes a
        complete payload into file_obj.

  Returns:
    A checker.PayloadChecker constructed with default options.
  """
  # Same sequence as writing to an in-memory file, rewinding, parsing,
  # initializing and wrapping: exactly what _GetPayloadChecker does when no
  # extra writer or checker arguments are given.
  return _GetPayloadChecker(payload_gen.WriteToFile)
70
71
72# This class doesn't need an __init__().
73# pylint: disable=W0232
74# Unit testing is all about running protected methods.
75# pylint: disable=W0212
76# Don't bark about missing members of classes you cannot import.
77# pylint: disable=E1101
78class PayloadCheckerTest(mox.MoxTestBase):
79  """Tests the PayloadChecker class.
80
81  In addition to ordinary testFoo() methods, which are automatically invoked by
82  the unittest framework, in this class we make use of DoBarTest() calls that
83  implement parametric tests of certain features. In order to invoke each test,
84  which embodies a unique combination of parameter values, as a complete unit
85  test, we perform explicit enumeration of the parameter space and create
86  individual invocation contexts for each, which are then bound as
87  testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
88  all such tests is done in AddAllParametricTests().
89  """
90
91  def MockPayload(self):
92    """Create a mock payload object, complete with a mock manifest."""
93    payload = self.mox.CreateMock(update_payload.Payload)
94    payload.is_init = True
95    payload.manifest = self.mox.CreateMock(
96        update_metadata_pb2.DeltaArchiveManifest)
97    return payload
98
99  @staticmethod
100  def NewExtent(start_block, num_blocks):
101    """Returns an Extent message.
102
103    Each of the provided fields is set iff it is >= 0; otherwise, it's left at
104    its default state.
105
106    Args:
107      start_block: The starting block of the extent.
108      num_blocks: The number of blocks in the extent.
109
110    Returns:
111      An Extent message.
112    """
113    ex = update_metadata_pb2.Extent()
114    if start_block >= 0:
115      ex.start_block = start_block
116    if num_blocks >= 0:
117      ex.num_blocks = num_blocks
118    return ex
119
120  @staticmethod
121  def NewExtentList(*args):
122    """Returns an list of extents.
123
124    Args:
125      *args: (start_block, num_blocks) pairs defining the extents.
126
127    Returns:
128      A list of Extent objects.
129    """
130    ex_list = []
131    for start_block, num_blocks in args:
132      ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks))
133    return ex_list
134
135  @staticmethod
136  def AddToMessage(repeated_field, field_vals):
137    for field_val in field_vals:
138      new_field = repeated_field.add()
139      new_field.CopyFrom(field_val)
140
141  def SetupAddElemTest(self, is_present, is_submsg, convert=str,
142                       linebreak=False, indent=0):
143    """Setup for testing of _CheckElem() and its derivatives.
144
145    Args:
146      is_present: Whether or not the element is found in the message.
147      is_submsg: Whether the element is a sub-message itself.
148      convert: A representation conversion function.
149      linebreak: Whether or not a linebreak is to be used in the report.
150      indent: Indentation used for the report.
151
152    Returns:
153      msg: A mock message object.
154      report: A mock report object.
155      subreport: A mock sub-report object.
156      name: An element name to check.
157      val: Expected element value.
158    """
159    name = 'foo'
160    val = 'fake submsg' if is_submsg else 'fake field'
161    subreport = 'fake subreport'
162
163    # Create a mock message.
164    msg = self.mox.CreateMock(update_metadata_pb2._message.Message)
165    msg.HasField(name).AndReturn(is_present)
166    setattr(msg, name, val)
167
168    # Create a mock report.
169    report = self.mox.CreateMock(checker._PayloadReport)
170    if is_present:
171      if is_submsg:
172        report.AddSubReport(name).AndReturn(subreport)
173      else:
174        report.AddField(name, convert(val), linebreak=linebreak, indent=indent)
175
176    self.mox.ReplayAll()
177    return (msg, report, subreport, name, val)
178
179  def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
180                    linebreak, indent):
181    """Parametric testing of _CheckElem().
182
183    Args:
184      is_present: Whether or not the element is found in the message.
185      is_mandatory: Whether or not it's a mandatory element.
186      is_submsg: Whether the element is a sub-message itself.
187      convert: A representation conversion function.
188      linebreak: Whether or not a linebreak is to be used in the report.
189      indent: Indentation used for the report.
190    """
191    msg, report, subreport, name, val = self.SetupAddElemTest(
192        is_present, is_submsg, convert, linebreak, indent)
193
194    args = (msg, name, report, is_mandatory, is_submsg)
195    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
196    if is_mandatory and not is_present:
197      self.assertRaises(update_payload.PayloadError,
198                        checker.PayloadChecker._CheckElem, *args, **kwargs)
199    else:
200      ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
201                                                                 **kwargs)
202      self.assertEquals(val if is_present else None, ret_val)
203      self.assertEquals(subreport if is_present and is_submsg else None,
204                        ret_subreport)
205
206  def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
207                     indent):
208    """Parametric testing of _Check{Mandatory,Optional}Field().
209
210    Args:
211      is_mandatory: Whether we're testing a mandatory call.
212      is_present: Whether or not the element is found in the message.
213      convert: A representation conversion function.
214      linebreak: Whether or not a linebreak is to be used in the report.
215      indent: Indentation used for the report.
216    """
217    msg, report, _, name, val = self.SetupAddElemTest(
218        is_present, False, convert, linebreak, indent)
219
220    # Prepare for invocation of the tested method.
221    args = [msg, name, report]
222    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
223    if is_mandatory:
224      args.append('bar')
225      tested_func = checker.PayloadChecker._CheckMandatoryField
226    else:
227      tested_func = checker.PayloadChecker._CheckOptionalField
228
229    # Test the method call.
230    if is_mandatory and not is_present:
231      self.assertRaises(update_payload.PayloadError, tested_func, *args,
232                        **kwargs)
233    else:
234      ret_val = tested_func(*args, **kwargs)
235      self.assertEquals(val if is_present else None, ret_val)
236
237  def DoAddSubMsgTest(self, is_mandatory, is_present):
238    """Parametrized testing of _Check{Mandatory,Optional}SubMsg().
239
240    Args:
241      is_mandatory: Whether we're testing a mandatory call.
242      is_present: Whether or not the element is found in the message.
243    """
244    msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)
245
246    # Prepare for invocation of the tested method.
247    args = [msg, name, report]
248    if is_mandatory:
249      args.append('bar')
250      tested_func = checker.PayloadChecker._CheckMandatorySubMsg
251    else:
252      tested_func = checker.PayloadChecker._CheckOptionalSubMsg
253
254    # Test the method call.
255    if is_mandatory and not is_present:
256      self.assertRaises(update_payload.PayloadError, tested_func, *args)
257    else:
258      ret_val, ret_subreport = tested_func(*args)
259      self.assertEquals(val if is_present else None, ret_val)
260      self.assertEquals(subreport if is_present else None, ret_subreport)
261
262  def testCheckPresentIff(self):
263    """Tests _CheckPresentIff()."""
264    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
265        None, None, 'foo', 'bar', 'baz'))
266    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
267        'a', 'b', 'foo', 'bar', 'baz'))
268    self.assertRaises(update_payload.PayloadError,
269                      checker.PayloadChecker._CheckPresentIff,
270                      'a', None, 'foo', 'bar', 'baz')
271    self.assertRaises(update_payload.PayloadError,
272                      checker.PayloadChecker._CheckPresentIff,
273                      None, 'b', 'foo', 'bar', 'baz')
274
275  def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
276                                 sig_data, sig_asn1_header,
277                                 returned_signed_hash, expected_signed_hash):
278    """Parametric testing of _CheckSha256SignatureTest().
279
280    Args:
281      expect_pass: Whether or not it should pass.
282      expect_subprocess_call: Whether to expect the openssl call to happen.
283      sig_data: The signature raw data.
284      sig_asn1_header: The ASN1 header.
285      returned_signed_hash: The signed hash data retuned by openssl.
286      expected_signed_hash: The signed hash data to compare against.
287    """
288    try:
289      # Stub out the subprocess invocation.
290      self.mox.StubOutWithMock(checker.PayloadChecker, '_Run')
291      if expect_subprocess_call:
292        checker.PayloadChecker._Run(
293            mox.IsA(list), send_data=sig_data).AndReturn(
294                (sig_asn1_header + returned_signed_hash, None))
295
296      self.mox.ReplayAll()
297      if expect_pass:
298        self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
299            sig_data, 'foo', expected_signed_hash, 'bar'))
300      else:
301        self.assertRaises(update_payload.PayloadError,
302                          checker.PayloadChecker._CheckSha256Signature,
303                          sig_data, 'foo', expected_signed_hash, 'bar')
304    finally:
305      self.mox.UnsetStubs()
306
307  def testCheckSha256Signature_Pass(self):
308    """Tests _CheckSha256Signature(); pass case."""
309    sig_data = 'fake-signature'.ljust(256)
310    signed_hash = hashlib.sha256('fake-data').digest()
311    self.DoCheckSha256SignatureTest(True, True, sig_data,
312                                    common.SIG_ASN1_HEADER, signed_hash,
313                                    signed_hash)
314
315  def testCheckSha256Signature_FailBadSignature(self):
316    """Tests _CheckSha256Signature(); fails due to malformed signature."""
317    sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
318    signed_hash = hashlib.sha256('fake-data').digest()
319    self.DoCheckSha256SignatureTest(False, False, sig_data,
320                                    common.SIG_ASN1_HEADER, signed_hash,
321                                    signed_hash)
322
323  def testCheckSha256Signature_FailBadOutputLength(self):
324    """Tests _CheckSha256Signature(); fails due to unexpected output length."""
325    sig_data = 'fake-signature'.ljust(256)
326    signed_hash = 'fake-hash'  # Malformed (not 32 bytes in length).
327    self.DoCheckSha256SignatureTest(False, True, sig_data,
328                                    common.SIG_ASN1_HEADER, signed_hash,
329                                    signed_hash)
330
331  def testCheckSha256Signature_FailBadAsnHeader(self):
332    """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
333    sig_data = 'fake-signature'.ljust(256)
334    signed_hash = hashlib.sha256('fake-data').digest()
335    bad_asn1_header = 'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
336    self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
337                                    signed_hash, signed_hash)
338
339  def testCheckSha256Signature_FailBadHash(self):
340    """Tests _CheckSha256Signature(); fails due to bad hash returned."""
341    sig_data = 'fake-signature'.ljust(256)
342    expected_signed_hash = hashlib.sha256('fake-data').digest()
343    returned_signed_hash = hashlib.sha256('bad-fake-data').digest()
344    self.DoCheckSha256SignatureTest(False, True, sig_data,
345                                    common.SIG_ASN1_HEADER,
346                                    expected_signed_hash, returned_signed_hash)
347
348  def testCheckBlocksFitLength_Pass(self):
349    """Tests _CheckBlocksFitLength(); pass case."""
350    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
351        64, 4, 16, 'foo'))
352    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
353        60, 4, 16, 'foo'))
354    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
355        49, 4, 16, 'foo'))
356    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
357        48, 3, 16, 'foo'))
358
359  def testCheckBlocksFitLength_TooManyBlocks(self):
360    """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
361    self.assertRaises(update_payload.PayloadError,
362                      checker.PayloadChecker._CheckBlocksFitLength,
363                      64, 5, 16, 'foo')
364    self.assertRaises(update_payload.PayloadError,
365                      checker.PayloadChecker._CheckBlocksFitLength,
366                      60, 5, 16, 'foo')
367    self.assertRaises(update_payload.PayloadError,
368                      checker.PayloadChecker._CheckBlocksFitLength,
369                      49, 5, 16, 'foo')
370    self.assertRaises(update_payload.PayloadError,
371                      checker.PayloadChecker._CheckBlocksFitLength,
372                      48, 4, 16, 'foo')
373
374  def testCheckBlocksFitLength_TooFewBlocks(self):
375    """Tests _CheckBlocksFitLength(); fails due to insufficient blocks."""
376    self.assertRaises(update_payload.PayloadError,
377                      checker.PayloadChecker._CheckBlocksFitLength,
378                      64, 3, 16, 'foo')
379    self.assertRaises(update_payload.PayloadError,
380                      checker.PayloadChecker._CheckBlocksFitLength,
381                      60, 3, 16, 'foo')
382    self.assertRaises(update_payload.PayloadError,
383                      checker.PayloadChecker._CheckBlocksFitLength,
384                      49, 3, 16, 'foo')
385    self.assertRaises(update_payload.PayloadError,
386                      checker.PayloadChecker._CheckBlocksFitLength,
387                      48, 2, 16, 'foo')
388
389  def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs,
390                          fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori,
391                          fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size,
392                          fail_old_rootfs_fs_size, fail_new_kernel_fs_size,
393                          fail_new_rootfs_fs_size):
394    """Parametric testing of _CheckManifest().
395
396    Args:
397      fail_mismatched_block_size: Simulate a missing block_size field.
398      fail_bad_sigs: Make signatures descriptor inconsistent.
399      fail_mismatched_oki_ori: Make old rootfs/kernel info partially present.
400      fail_bad_oki: Tamper with old kernel info.
401      fail_bad_ori: Tamper with old rootfs info.
402      fail_bad_nki: Tamper with new kernel info.
403      fail_bad_nri: Tamper with new rootfs info.
404      fail_old_kernel_fs_size: Make old kernel fs size too big.
405      fail_old_rootfs_fs_size: Make old rootfs fs size too big.
406      fail_new_kernel_fs_size: Make new kernel fs size too big.
407      fail_new_rootfs_fs_size: Make new rootfs fs size too big.
408    """
409    # Generate a test payload. For this test, we only care about the manifest
410    # and don't need any data blobs, hence we can use a plain paylaod generator
411    # (which also gives us more control on things that can be screwed up).
412    payload_gen = test_utils.PayloadGenerator()
413
414    # Tamper with block size, if required.
415    if fail_mismatched_block_size:
416      payload_gen.SetBlockSize(test_utils.KiB(1))
417    else:
418      payload_gen.SetBlockSize(test_utils.KiB(4))
419
420    # Add some operations.
421    payload_gen.AddOperation(False, common.OpType.MOVE,
422                             src_extents=[(0, 16), (16, 497)],
423                             dst_extents=[(16, 496), (0, 16)])
424    payload_gen.AddOperation(True, common.OpType.MOVE,
425                             src_extents=[(0, 8), (8, 8)],
426                             dst_extents=[(8, 8), (0, 8)])
427
428    # Set an invalid signatures block (offset but no size), if required.
429    if fail_bad_sigs:
430      payload_gen.SetSignatures(32, None)
431
432    # Set partition / filesystem sizes.
433    rootfs_part_size = test_utils.MiB(8)
434    kernel_part_size = test_utils.KiB(512)
435    old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size
436    old_kernel_fs_size = new_kernel_fs_size = kernel_part_size
437    if fail_old_kernel_fs_size:
438      old_kernel_fs_size += 100
439    if fail_old_rootfs_fs_size:
440      old_rootfs_fs_size += 100
441    if fail_new_kernel_fs_size:
442      new_kernel_fs_size += 100
443    if fail_new_rootfs_fs_size:
444      new_rootfs_fs_size += 100
445
446    # Add old kernel/rootfs partition info, as required.
447    if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki:
448      oki_hash = (None if fail_bad_oki
449                  else hashlib.sha256('fake-oki-content').digest())
450      payload_gen.SetPartInfo(True, False, old_kernel_fs_size, oki_hash)
451    if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or
452                                        fail_bad_ori):
453      ori_hash = (None if fail_bad_ori
454                  else hashlib.sha256('fake-ori-content').digest())
455      payload_gen.SetPartInfo(False, False, old_rootfs_fs_size, ori_hash)
456
457    # Add new kernel/rootfs partition info.
458    payload_gen.SetPartInfo(
459        True, True, new_kernel_fs_size,
460        None if fail_bad_nki else hashlib.sha256('fake-nki-content').digest())
461    payload_gen.SetPartInfo(
462        False, True, new_rootfs_fs_size,
463        None if fail_bad_nri else hashlib.sha256('fake-nri-content').digest())
464
465    # Set the minor version.
466    payload_gen.SetMinorVersion(0)
467
468    # Create the test object.
469    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile)
470    report = checker._PayloadReport()
471
472    should_fail = (fail_mismatched_block_size or fail_bad_sigs or
473                   fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or
474                   fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or
475                   fail_old_rootfs_fs_size or fail_new_kernel_fs_size or
476                   fail_new_rootfs_fs_size)
477    if should_fail:
478      self.assertRaises(update_payload.PayloadError,
479                        payload_checker._CheckManifest, report,
480                        rootfs_part_size, kernel_part_size)
481    else:
482      self.assertIsNone(payload_checker._CheckManifest(report,
483                                                       rootfs_part_size,
484                                                       kernel_part_size))
485
486  def testCheckLength(self):
487    """Tests _CheckLength()."""
488    payload_checker = checker.PayloadChecker(self.MockPayload())
489    block_size = payload_checker.block_size
490
491    # Passes.
492    self.assertIsNone(payload_checker._CheckLength(
493        int(3.5 * block_size), 4, 'foo', 'bar'))
494    # Fails, too few blocks.
495    self.assertRaises(update_payload.PayloadError,
496                      payload_checker._CheckLength,
497                      int(3.5 * block_size), 3, 'foo', 'bar')
498    # Fails, too many blocks.
499    self.assertRaises(update_payload.PayloadError,
500                      payload_checker._CheckLength,
501                      int(3.5 * block_size), 5, 'foo', 'bar')
502
503  def testCheckExtents(self):
504    """Tests _CheckExtents()."""
505    payload_checker = checker.PayloadChecker(self.MockPayload())
506    block_size = payload_checker.block_size
507
508    # Passes w/ all real extents.
509    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
510    self.assertEquals(
511        23,
512        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
513                                      collections.defaultdict(int), 'foo'))
514
515    # Passes w/ pseudo-extents (aka sparse holes).
516    extents = self.NewExtentList((0, 4), (common.PSEUDO_EXTENT_MARKER, 5),
517                                 (8, 3))
518    self.assertEquals(
519        12,
520        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
521                                      collections.defaultdict(int), 'foo',
522                                      allow_pseudo=True))
523
524    # Passes w/ pseudo-extent due to a signature.
525    extents = self.NewExtentList((common.PSEUDO_EXTENT_MARKER, 2))
526    self.assertEquals(
527        2,
528        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
529                                      collections.defaultdict(int), 'foo',
530                                      allow_signature=True))
531
532    # Fails, extent missing a start block.
533    extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16))
534    self.assertRaises(
535        update_payload.PayloadError, payload_checker._CheckExtents,
536        extents, (1024 + 16) * block_size, collections.defaultdict(int),
537        'foo')
538
539    # Fails, extent missing block count.
540    extents = self.NewExtentList((0, -1), (8, 3), (1024, 16))
541    self.assertRaises(
542        update_payload.PayloadError, payload_checker._CheckExtents,
543        extents, (1024 + 16) * block_size, collections.defaultdict(int),
544        'foo')
545
546    # Fails, extent has zero blocks.
547    extents = self.NewExtentList((0, 4), (8, 3), (1024, 0))
548    self.assertRaises(
549        update_payload.PayloadError, payload_checker._CheckExtents,
550        extents, (1024 + 16) * block_size, collections.defaultdict(int),
551        'foo')
552
553    # Fails, extent exceeds partition boundaries.
554    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
555    self.assertRaises(
556        update_payload.PayloadError, payload_checker._CheckExtents,
557        extents, (1024 + 15) * block_size, collections.defaultdict(int),
558        'foo')
559
560  def testCheckReplaceOperation(self):
561    """Tests _CheckReplaceOperation() where op.type == REPLACE."""
562    payload_checker = checker.PayloadChecker(self.MockPayload())
563    block_size = payload_checker.block_size
564    data_length = 10000
565
566    op = self.mox.CreateMock(
567        update_metadata_pb2.InstallOperation)
568    op.type = common.OpType.REPLACE
569
570    # Pass.
571    op.src_extents = []
572    self.assertIsNone(
573        payload_checker._CheckReplaceOperation(
574            op, data_length, (data_length + block_size - 1) / block_size,
575            'foo'))
576
577    # Fail, src extents founds.
578    op.src_extents = ['bar']
579    self.assertRaises(
580        update_payload.PayloadError,
581        payload_checker._CheckReplaceOperation,
582        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
583
584    # Fail, missing data.
585    op.src_extents = []
586    self.assertRaises(
587        update_payload.PayloadError,
588        payload_checker._CheckReplaceOperation,
589        op, None, (data_length + block_size - 1) / block_size, 'foo')
590
591    # Fail, length / block number mismatch.
592    op.src_extents = ['bar']
593    self.assertRaises(
594        update_payload.PayloadError,
595        payload_checker._CheckReplaceOperation,
596        op, data_length, (data_length + block_size - 1) / block_size + 1, 'foo')
597
598  def testCheckReplaceBzOperation(self):
599    """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ."""
600    payload_checker = checker.PayloadChecker(self.MockPayload())
601    block_size = payload_checker.block_size
602    data_length = block_size * 3
603
604    op = self.mox.CreateMock(
605        update_metadata_pb2.InstallOperation)
606    op.type = common.OpType.REPLACE_BZ
607
608    # Pass.
609    op.src_extents = []
610    self.assertIsNone(
611        payload_checker._CheckReplaceOperation(
612            op, data_length, (data_length + block_size - 1) / block_size + 5,
613            'foo'))
614
615    # Fail, src extents founds.
616    op.src_extents = ['bar']
617    self.assertRaises(
618        update_payload.PayloadError,
619        payload_checker._CheckReplaceOperation,
620        op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo')
621
622    # Fail, missing data.
623    op.src_extents = []
624    self.assertRaises(
625        update_payload.PayloadError,
626        payload_checker._CheckReplaceOperation,
627        op, None, (data_length + block_size - 1) / block_size, 'foo')
628
629    # Fail, too few blocks to justify BZ.
630    op.src_extents = []
631    self.assertRaises(
632        update_payload.PayloadError,
633        payload_checker._CheckReplaceOperation,
634        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
635
636  def testCheckMoveOperation_Pass(self):
637    """Tests _CheckMoveOperation(); pass case."""
638    payload_checker = checker.PayloadChecker(self.MockPayload())
639    op = update_metadata_pb2.InstallOperation()
640    op.type = common.OpType.MOVE
641
642    self.AddToMessage(op.src_extents,
643                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
644    self.AddToMessage(op.dst_extents,
645                      self.NewExtentList((16, 128), (512, 6)))
646    self.assertIsNone(
647        payload_checker._CheckMoveOperation(op, None, 134, 134, 'foo'))
648
649  def testCheckMoveOperation_FailContainsData(self):
650    """Tests _CheckMoveOperation(); fails, message contains data."""
651    payload_checker = checker.PayloadChecker(self.MockPayload())
652    op = update_metadata_pb2.InstallOperation()
653    op.type = common.OpType.MOVE
654
655    self.AddToMessage(op.src_extents,
656                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
657    self.AddToMessage(op.dst_extents,
658                      self.NewExtentList((16, 128), (512, 6)))
659    self.assertRaises(
660        update_payload.PayloadError,
661        payload_checker._CheckMoveOperation,
662        op, 1024, 134, 134, 'foo')
663
664  def testCheckMoveOperation_FailInsufficientSrcBlocks(self):
665    """Tests _CheckMoveOperation(); fails, not enough actual src blocks."""
666    payload_checker = checker.PayloadChecker(self.MockPayload())
667    op = update_metadata_pb2.InstallOperation()
668    op.type = common.OpType.MOVE
669
670    self.AddToMessage(op.src_extents,
671                      self.NewExtentList((1, 4), (12, 2), (1024, 127)))
672    self.AddToMessage(op.dst_extents,
673                      self.NewExtentList((16, 128), (512, 6)))
674    self.assertRaises(
675        update_payload.PayloadError,
676        payload_checker._CheckMoveOperation,
677        op, None, 134, 134, 'foo')
678
679  def testCheckMoveOperation_FailInsufficientDstBlocks(self):
680    """Tests _CheckMoveOperation(); fails, not enough actual dst blocks."""
681    payload_checker = checker.PayloadChecker(self.MockPayload())
682    op = update_metadata_pb2.InstallOperation()
683    op.type = common.OpType.MOVE
684
685    self.AddToMessage(op.src_extents,
686                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
687    self.AddToMessage(op.dst_extents,
688                      self.NewExtentList((16, 128), (512, 5)))
689    self.assertRaises(
690        update_payload.PayloadError,
691        payload_checker._CheckMoveOperation,
692        op, None, 134, 134, 'foo')
693
694  def testCheckMoveOperation_FailExcessSrcBlocks(self):
695    """Tests _CheckMoveOperation(); fails, too many actual src blocks."""
696    payload_checker = checker.PayloadChecker(self.MockPayload())
697    op = update_metadata_pb2.InstallOperation()
698    op.type = common.OpType.MOVE
699
700    self.AddToMessage(op.src_extents,
701                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
702    self.AddToMessage(op.dst_extents,
703                      self.NewExtentList((16, 128), (512, 5)))
704    self.assertRaises(
705        update_payload.PayloadError,
706        payload_checker._CheckMoveOperation,
707        op, None, 134, 134, 'foo')
708    self.AddToMessage(op.src_extents,
709                      self.NewExtentList((1, 4), (12, 2), (1024, 129)))
710    self.AddToMessage(op.dst_extents,
711                      self.NewExtentList((16, 128), (512, 6)))
712    self.assertRaises(
713        update_payload.PayloadError,
714        payload_checker._CheckMoveOperation,
715        op, None, 134, 134, 'foo')
716
717  def testCheckMoveOperation_FailExcessDstBlocks(self):
718    """Tests _CheckMoveOperation(); fails, too many actual dst blocks."""
719    payload_checker = checker.PayloadChecker(self.MockPayload())
720    op = update_metadata_pb2.InstallOperation()
721    op.type = common.OpType.MOVE
722
723    self.AddToMessage(op.src_extents,
724                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
725    self.AddToMessage(op.dst_extents,
726                      self.NewExtentList((16, 128), (512, 7)))
727    self.assertRaises(
728        update_payload.PayloadError,
729        payload_checker._CheckMoveOperation,
730        op, None, 134, 134, 'foo')
731
732  def testCheckMoveOperation_FailStagnantBlocks(self):
733    """Tests _CheckMoveOperation(); fails, there are blocks that do not move."""
734    payload_checker = checker.PayloadChecker(self.MockPayload())
735    op = update_metadata_pb2.InstallOperation()
736    op.type = common.OpType.MOVE
737
738    self.AddToMessage(op.src_extents,
739                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
740    self.AddToMessage(op.dst_extents,
741                      self.NewExtentList((8, 128), (512, 6)))
742    self.assertRaises(
743        update_payload.PayloadError,
744        payload_checker._CheckMoveOperation,
745        op, None, 134, 134, 'foo')
746
747  def testCheckMoveOperation_FailZeroStartBlock(self):
748    """Tests _CheckMoveOperation(); fails, has extent with start block 0."""
749    payload_checker = checker.PayloadChecker(self.MockPayload())
750    op = update_metadata_pb2.InstallOperation()
751    op.type = common.OpType.MOVE
752
753    self.AddToMessage(op.src_extents,
754                      self.NewExtentList((0, 4), (12, 2), (1024, 128)))
755    self.AddToMessage(op.dst_extents,
756                      self.NewExtentList((8, 128), (512, 6)))
757    self.assertRaises(
758        update_payload.PayloadError,
759        payload_checker._CheckMoveOperation,
760        op, None, 134, 134, 'foo')
761
762    self.AddToMessage(op.src_extents,
763                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
764    self.AddToMessage(op.dst_extents,
765                      self.NewExtentList((0, 128), (512, 6)))
766    self.assertRaises(
767        update_payload.PayloadError,
768        payload_checker._CheckMoveOperation,
769        op, None, 134, 134, 'foo')
770
771  def testCheckAnyDiff(self):
772    """Tests _CheckAnyDiffOperation()."""
773    payload_checker = checker.PayloadChecker(self.MockPayload())
774
775    # Pass.
776    self.assertIsNone(
777        payload_checker._CheckAnyDiffOperation(10000, 3, 'foo'))
778
779    # Fail, missing data blob.
780    self.assertRaises(
781        update_payload.PayloadError,
782        payload_checker._CheckAnyDiffOperation,
783        None, 3, 'foo')
784
785    # Fail, too big of a diff blob (unjustified).
786    self.assertRaises(
787        update_payload.PayloadError,
788        payload_checker._CheckAnyDiffOperation,
789        10000, 2, 'foo')
790
791  def testCheckSourceCopyOperation_Pass(self):
792    """Tests _CheckSourceCopyOperation(); pass case."""
793    payload_checker = checker.PayloadChecker(self.MockPayload())
794    self.assertIsNone(
795        payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))
796
797  def testCheckSourceCopyOperation_FailContainsData(self):
798    """Tests _CheckSourceCopyOperation(); message contains data."""
799    payload_checker = checker.PayloadChecker(self.MockPayload())
800    self.assertRaises(update_payload.PayloadError,
801                      payload_checker._CheckSourceCopyOperation,
802                      134, 0, 0, 'foo')
803
804  def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
805    """Tests _CheckSourceCopyOperation(); src and dst block totals not equal."""
806    payload_checker = checker.PayloadChecker(self.MockPayload())
807    self.assertRaises(update_payload.PayloadError,
808                      payload_checker._CheckSourceCopyOperation,
809                      None, 0, 1, 'foo')
810
  def DoCheckOperationTest(self, op_type_name, is_last, allow_signature,
                           allow_unhashed, fail_src_extents, fail_dst_extents,
                           fail_mismatched_data_offset_length,
                           fail_missing_dst_extents, fail_src_length,
                           fail_dst_length, fail_data_hash,
                           fail_prev_data_offset, fail_bad_minor_version):
    """Parametric testing of _CheckOperation().

    Builds a single InstallOperation of the requested type, tampering with it
    according to the fail_* flags. If any fail_* flag is set,
    _CheckOperation() is expected to raise PayloadError. Otherwise it is
    expected to return the operation's data_length (0 when unset).

    Args:
      op_type_name: 'REPLACE', 'REPLACE_BZ', 'MOVE', 'BSDIFF', 'SOURCE_COPY',
        or 'SOURCE_BSDIFF'.
      is_last: Whether we're testing the last operation in a sequence.
      allow_signature: Whether we're testing a signature-capable operation.
      allow_unhashed: Whether unhashed data blobs are permitted.
      fail_src_extents: Tamper with src extents (use a zero-length extent).
      fail_dst_extents: Tamper with dst extents (include a zero-length extent).
      fail_mismatched_data_offset_length: Make data_{offset,length}
        inconsistent (set data_offset but leave data_length unset).
      fail_missing_dst_extents: Do not include dst extents.
      fail_src_length: Make src length inconsistent.
      fail_dst_length: Make dst length inconsistent.
      fail_data_hash: Tamper with the data blob hash.
      fail_prev_data_offset: Make data space uses non-contiguous (leave a gap
        after the previous data blob).
      fail_bad_minor_version: Make minor version incompatible with op.
    """
    op_type = _OpTypeByName(op_type_name)

    # Create the test object.
    payload = self.MockPayload()
    payload_checker = checker.PayloadChecker(payload,
                                             allow_unhashed=allow_unhashed)
    block_size = payload_checker.block_size

    # Create auxiliary arguments.
    old_part_size = test_utils.MiB(4)
    new_part_size = test_utils.MiB(8)
    # One zeroed byte-sized counter per block; sizes are rounded up to whole
    # blocks.
    old_block_counters = array.array(
        'B', [0] * ((old_part_size + block_size - 1) / block_size))
    new_block_counters = array.array(
        'B', [0] * ((new_part_size + block_size - 1) / block_size))
    prev_data_offset = 1876
    blob_hash_counts = collections.defaultdict(int)

    # Create the operation object for the test.
    op = update_metadata_pb2.InstallOperation()
    op.type = op_type

    # Only these op types read from the source partition and get src extents.
    # The tampered variant uses a zero-length extent.
    total_src_blocks = 0
    if op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                   common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
      if fail_src_extents:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 0)))
      else:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 16)))
        total_src_blocks = 16

    # Pick a minor version the op type is compatible with, or swap it for the
    # other delta version when fail_bad_minor_version is set.
    if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
      payload_checker.minor_version = 0
    elif op_type in (common.OpType.MOVE, common.OpType.BSDIFF):
      payload_checker.minor_version = 2 if fail_bad_minor_version else 1
    elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
      payload_checker.minor_version = 1 if fail_bad_minor_version else 2

    # MOVE and SOURCE_COPY carry no data blob; all other types get
    # data_offset/data_length and, depending on the flags, a data hash.
    if op_type not in (common.OpType.MOVE, common.OpType.SOURCE_COPY):
      if not fail_mismatched_data_offset_length:
        op.data_length = 16 * block_size - 8
      if fail_prev_data_offset:
        op.data_offset = prev_data_offset + 16
      else:
        op.data_offset = prev_data_offset

      fake_data = 'fake-data'.ljust(op.data_length)
      # A hash is mandatory unless unhashed data is allowed, or this is a
      # signature-capable final REPLACE op. When mandatory, fail_data_hash
      # leaves the hash out entirely. When optional, fail_data_hash records a
      # hash of different content.
      # The mock's ReadDataBlob() expectation is recorded only when a hash is
      # set, presumably because the checker reads the blob only to verify it.
      if not (allow_unhashed or (is_last and allow_signature and
                                 op_type == common.OpType.REPLACE)):
        if not fail_data_hash:
          # Create a valid data blob hash.
          op.data_sha256_hash = hashlib.sha256(fake_data).digest()
          payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
              fake_data)
      elif fail_data_hash:
        # Create an invalid data blob hash.
        op.data_sha256_hash = hashlib.sha256(
            fake_data.replace(' ', '-')).digest()
        payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
            fake_data)

    # Dst extents: 16 blocks in total. The tampered variant includes a
    # zero-length extent.
    total_dst_blocks = 0
    if not fail_missing_dst_extents:
      total_dst_blocks = 16
      if fail_dst_extents:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 16), (32, 0)))
      else:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 8), (64, 8)))

    # Lengths are block-accurate unless tampered with (+8 bytes).
    if total_src_blocks:
      if fail_src_length:
        op.src_length = total_src_blocks * block_size + 8
      else:
        op.src_length = total_src_blocks * block_size
    elif fail_src_length:
      # Add an orphaned src_length.
      op.src_length = 16

    if total_dst_blocks:
      if fail_dst_length:
        op.dst_length = total_dst_blocks * block_size + 8
      else:
        op.dst_length = total_dst_blocks * block_size

    # Done recording mock expectations; switch the mocks to replay mode.
    self.mox.ReplayAll()
    should_fail = (fail_src_extents or fail_dst_extents or
                   fail_mismatched_data_offset_length or
                   fail_missing_dst_extents or fail_src_length or
                   fail_dst_length or fail_data_hash or fail_prev_data_offset or
                   fail_bad_minor_version)
    args = (op, 'foo', is_last, old_block_counters, new_block_counters,
            old_part_size, new_part_size, prev_data_offset, allow_signature,
            blob_hash_counts)
    if should_fail:
      self.assertRaises(update_payload.PayloadError,
                        payload_checker._CheckOperation, *args)
    else:
      # On success the checker reports how much blob data the op uses.
      self.assertEqual(op.data_length if op.HasField('data_length') else 0,
                       payload_checker._CheckOperation(*args))
939
940  def testAllocBlockCounters(self):
941    """Tests _CheckMoveOperation()."""
942    payload_checker = checker.PayloadChecker(self.MockPayload())
943    block_size = payload_checker.block_size
944
945    # Check allocation for block-aligned partition size, ensure it's integers.
946    result = payload_checker._AllocBlockCounters(16 * block_size)
947    self.assertEqual(16, len(result))
948    self.assertEqual(int, type(result[0]))
949
950    # Check allocation of unaligned partition sizes.
951    result = payload_checker._AllocBlockCounters(16 * block_size - 1)
952    self.assertEqual(16, len(result))
953    result = payload_checker._AllocBlockCounters(16 * block_size + 1)
954    self.assertEqual(17, len(result))
955
956  def DoCheckOperationsTest(self, fail_nonexhaustive_full_update):
957    # Generate a test payload. For this test, we only care about one
958    # (arbitrary) set of operations, so we'll only be generating kernel and
959    # test with them.
960    payload_gen = test_utils.PayloadGenerator()
961
962    block_size = test_utils.KiB(4)
963    payload_gen.SetBlockSize(block_size)
964
965    rootfs_part_size = test_utils.MiB(8)
966
967    # Fake rootfs operations in a full update, tampered with as required.
968    rootfs_op_type = common.OpType.REPLACE
969    rootfs_data_length = rootfs_part_size
970    if fail_nonexhaustive_full_update:
971      rootfs_data_length -= block_size
972
973    payload_gen.AddOperation(False, rootfs_op_type,
974                             dst_extents=[(0, rootfs_data_length / block_size)],
975                             data_offset=0,
976                             data_length=rootfs_data_length)
977
978    # Create the test object.
979    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile,
980                                         checker_init_dargs={
981                                             'allow_unhashed': True})
982    payload_checker.payload_type = checker._TYPE_FULL
983    report = checker._PayloadReport()
984
985    args = (payload_checker.payload.manifest.install_operations, report,
986            'foo', 0, rootfs_part_size, rootfs_part_size, 0, False)
987    if fail_nonexhaustive_full_update:
988      self.assertRaises(update_payload.PayloadError,
989                        payload_checker._CheckOperations, *args)
990    else:
991      self.assertEqual(rootfs_data_length,
992                       payload_checker._CheckOperations(*args))
993
  def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_missing_pseudo_op,
                            fail_mismatched_pseudo_op, fail_sig_missing_fields,
                            fail_unknown_sig_version, fail_incorrect_sig):
    """Parametric testing of _CheckSignatures().

    With no fail_* flag set, the payload is written and signed by the generator
    using test_utils._PRIVKEY_FILE_NAME. The generator also adds the signature
    pseudo-operation, and the check is expected to pass. Any fail_* flag
    switches to a hand-forged signatures blob, and the check is expected to
    raise PayloadError.

    Args:
      fail_empty_sigs_blob: Forge a signatures blob that holds no signature.
      fail_missing_pseudo_op: Do not let the generator add the signature
        pseudo-operation; a forged REPLACE op is added in its place.
      fail_mismatched_pseudo_op: Same forged pseudo-operation as
        fail_missing_pseudo_op, whose data_offset/data_length are halved
        relative to the signatures blob.
      fail_sig_missing_fields: Add a signature entry without signature data.
      fail_unknown_sig_version: Add a signature with version 5 instead of 1.
      fail_incorrect_sig: Forge a signature over 'fake-payload-content'
        rather than over the actual payload.
    """
    # Generate a test payload. For this test, we only care about the signature
    # block and how it relates to the payload hash. Therefore, we're generating
    # a random (otherwise useless) payload for this purpose.
    payload_gen = test_utils.EnhancedPayloadGenerator()
    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)
    rootfs_part_size = test_utils.MiB(2)
    kernel_part_size = test_utils.KiB(16)
    payload_gen.SetPartInfo(False, True, rootfs_part_size,
                            hashlib.sha256('fake-new-rootfs-content').digest())
    payload_gen.SetPartInfo(True, True, kernel_part_size,
                            hashlib.sha256('fake-new-kernel-content').digest())
    payload_gen.SetMinorVersion(0)
    payload_gen.AddOperationWithData(
        False, common.OpType.REPLACE,
        dst_extents=[(0, rootfs_part_size / block_size)],
        data_blob=os.urandom(rootfs_part_size))

    # Both pseudo-op failure modes share one forged operation (see below); any
    # failure mode requires hand-forging the signatures blob.
    do_forge_pseudo_op = (fail_missing_pseudo_op or fail_mismatched_pseudo_op)
    do_forge_sigs_data = (do_forge_pseudo_op or fail_empty_sigs_blob or
                          fail_sig_missing_fields or fail_unknown_sig_version
                          or fail_incorrect_sig)

    sigs_data = None
    if do_forge_sigs_data:
      sigs_gen = test_utils.SignaturesGenerator()
      if not fail_empty_sigs_blob:
        if fail_sig_missing_fields:
          sig_data = None
        else:
          # Signs a fixed string rather than the real payload.
          sig_data = test_utils.SignSha256('fake-payload-content',
                                           test_utils._PRIVKEY_FILE_NAME)
        sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data)

      sigs_data = sigs_gen.ToBinary()
      # Record the forged blob's location (the generator's current offset)
      # and size.
      payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data))

    if do_forge_pseudo_op:
      assert sigs_data is not None, 'should have forged signatures blob by now'
      sigs_len = len(sigs_data)
      # Forged pseudo-op: offset and length deliberately halved relative to
      # the values passed to SetSignatures() above.
      payload_gen.AddOperation(
          False, common.OpType.REPLACE,
          data_offset=payload_gen.curr_offset / 2,
          data_length=sigs_len / 2,
          dst_extents=[(0, (sigs_len / 2 + block_size - 1) / block_size)])

    # Generate payload (complete w/ signature) and create the test object.
    payload_checker = _GetPayloadChecker(
        payload_gen.WriteToFileWithData,
        payload_gen_dargs={
            'sigs_data': sigs_data,
            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
            'do_add_pseudo_operation': not do_forge_pseudo_op})
    payload_checker.payload_type = checker._TYPE_FULL
    report = checker._PayloadReport()

    # We have to check the manifest first in order to set signature attributes.
    payload_checker._CheckManifest(report, rootfs_part_size, kernel_part_size)

    should_fail = (fail_empty_sigs_blob or fail_missing_pseudo_op or
                   fail_mismatched_pseudo_op or fail_sig_missing_fields or
                   fail_unknown_sig_version or fail_incorrect_sig)
    args = (report, test_utils._PUBKEY_FILE_NAME)
    if should_fail:
      self.assertRaises(update_payload.PayloadError,
                        payload_checker._CheckSignatures, *args)
    else:
      self.assertIsNone(payload_checker._CheckSignatures(*args))
1065
1066  def DoCheckManifestMinorVersionTest(self, minor_version, payload_type):
1067    """Parametric testing for CheckManifestMinorVersion().
1068
1069    Args:
1070      minor_version: The payload minor version to test with.
1071      payload_type: The type of the payload we're testing, delta or full.
1072    """
1073    # Create the test object.
1074    payload = self.MockPayload()
1075    payload.manifest.minor_version = minor_version
1076    payload_checker = checker.PayloadChecker(payload)
1077    payload_checker.payload_type = payload_type
1078    report = checker._PayloadReport()
1079
1080    should_succeed = (
1081        (minor_version == 0 and payload_type == checker._TYPE_FULL) or
1082        (minor_version == 1 and payload_type == checker._TYPE_DELTA) or
1083        (minor_version == 2 and payload_type == checker._TYPE_DELTA) or
1084        (minor_version == 3 and payload_type == checker._TYPE_DELTA) or
1085        (minor_version == 4 and payload_type == checker._TYPE_DELTA))
1086    args = (report,)
1087
1088    if should_succeed:
1089      self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args))
1090    else:
1091      self.assertRaises(update_payload.PayloadError,
1092                        payload_checker._CheckManifestMinorVersion, *args)
1093
  def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided,
                fail_wrong_payload_type, fail_invalid_block_size,
                fail_mismatched_block_size, fail_excess_data,
                fail_rootfs_part_size_exceeded,
                fail_kernel_part_size_exceeded):
    """Parametric testing of PayloadChecker.Run().

    An invalid block size is expected to make building the checker itself
    fail. Every other fail_* flag is expected to make Run() raise
    PayloadError. With no fail_* flag set, Run() is expected to return None.

    Args:
      rootfs_part_size_provided: Pass Run() an explicit rootfs partition size,
        one block larger than the rootfs filesystem; otherwise 0 is passed.
      kernel_part_size_provided: Same as above, for the kernel partition.
      fail_wrong_payload_type: Make the checker assert a 'delta' payload.
      fail_invalid_block_size: Give the checker a block size that is not a
        power of two.
      fail_mismatched_block_size: Give the checker twice the payload's block
        size.
      fail_excess_data: Append 1024 bytes of random padding to the payload.
      fail_rootfs_part_size_exceeded: Make the rootfs REPLACE op one block
        larger than the partition (or filesystem, if no size is provided).
      fail_kernel_part_size_exceeded: Same as above, for the kernel op.
    """
    # Generate a test payload. For this test, we generate a full update that
    # has sample kernel and rootfs operations. Since most testing is done with
    # internal PayloadChecker methods that are tested elsewhere, here we only
    # tamper with what's actually being manipulated and/or tested in the Run()
    # method itself. Note that the checker doesn't verify partition hashes, so
    # they're safe to fake.
    payload_gen = test_utils.EnhancedPayloadGenerator()
    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)
    kernel_filesystem_size = test_utils.KiB(16)
    rootfs_filesystem_size = test_utils.MiB(2)
    payload_gen.SetPartInfo(False, True, rootfs_filesystem_size,
                            hashlib.sha256('fake-new-rootfs-content').digest())
    payload_gen.SetPartInfo(True, True, kernel_filesystem_size,
                            hashlib.sha256('fake-new-kernel-content').digest())
    payload_gen.SetMinorVersion(0)

    # The rootfs op fills the partition when a size is given, else the
    # filesystem; it is one block too large when tampered with.
    rootfs_part_size = 0
    if rootfs_part_size_provided:
      rootfs_part_size = rootfs_filesystem_size + block_size
    rootfs_op_size = rootfs_part_size or rootfs_filesystem_size
    if fail_rootfs_part_size_exceeded:
      rootfs_op_size += block_size
    payload_gen.AddOperationWithData(
        False, common.OpType.REPLACE,
        dst_extents=[(0, rootfs_op_size / block_size)],
        data_blob=os.urandom(rootfs_op_size))

    # Same scheme for the kernel op.
    kernel_part_size = 0
    if kernel_part_size_provided:
      kernel_part_size = kernel_filesystem_size + block_size
    kernel_op_size = kernel_part_size or kernel_filesystem_size
    if fail_kernel_part_size_exceeded:
      kernel_op_size += block_size
    payload_gen.AddOperationWithData(
        True, common.OpType.REPLACE,
        dst_extents=[(0, kernel_op_size / block_size)],
        data_blob=os.urandom(kernel_op_size))

    # Generate payload (complete w/ signature) and create the test object.
    if fail_invalid_block_size:
      use_block_size = block_size + 5  # Not a power of two.
    elif fail_mismatched_block_size:
      use_block_size = block_size * 2  # Differs from the payload's block size.
    else:
      use_block_size = block_size

    kwargs = {
        'payload_gen_dargs': {
            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
            'do_add_pseudo_operation': True,
            'is_pseudo_in_kernel': True,
            'padding': os.urandom(1024) if fail_excess_data else None},
        'checker_init_dargs': {
            'assert_type': 'delta' if fail_wrong_payload_type else 'full',
            'block_size': use_block_size}}
    if fail_invalid_block_size:
      # The bad block size is expected to be rejected while building the
      # checker, so Run() is never reached.
      self.assertRaises(update_payload.PayloadError, _GetPayloadChecker,
                        payload_gen.WriteToFileWithData, **kwargs)
    else:
      payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
                                           **kwargs)

      # kwargs is reused here for the Run() call.
      kwargs = {'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
                'rootfs_part_size': rootfs_part_size,
                'kernel_part_size': kernel_part_size}
      should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
                     fail_excess_data or
                     fail_rootfs_part_size_exceeded or
                     fail_kernel_part_size_exceeded)
      if should_fail:
        self.assertRaises(update_payload.PayloadError, payload_checker.Run,
                          **kwargs)
      else:
        self.assertIsNone(payload_checker.Run(**kwargs))
1174
1175# This implements a generic API, hence the occasional unused args.
1176# pylint: disable=W0613
def ValidateCheckOperationTest(op_type_name, is_last, allow_signature,
                               allow_unhashed, fail_src_extents,
                               fail_dst_extents,
                               fail_mismatched_data_offset_length,
                               fail_missing_dst_extents, fail_src_length,
                               fail_dst_length, fail_data_hash,
                               fail_prev_data_offset, fail_bad_minor_version):
  """Returns True iff the combination of arguments represents a valid test.

  A combination is rejected when it asks for a failure mode the operation
  type cannot exhibit. Arguments not consulted here are accepted so that the
  function can be called with the full parametric argument set.
  """
  op_type = _OpTypeByName(op_type_name)

  # REPLACE/REPLACE_BZ operations don't read data from src partition. They are
  # compatible with all valid minor versions, so we don't need to check that.
  has_no_src = op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ)
  wants_src_failure = (fail_src_extents or fail_src_length or
                       fail_bad_minor_version)
  if has_no_src and wants_src_failure:
    return False

  # MOVE and SOURCE_COPY operations don't carry data.
  has_no_data = op_type in (common.OpType.MOVE, common.OpType.SOURCE_COPY)
  wants_data_failure = (fail_mismatched_data_offset_length or fail_data_hash or
                        fail_prev_data_offset)
  return not (has_no_data and wants_data_failure)
1200
1201
def TestMethodBody(run_method_name, run_dargs):
  """Returns a function that invokes a named method with named arguments.

  The returned function takes a single `self` argument, so it can be bound as
  a method. When called, it looks up `run_method_name` on that object, calls
  it with `run_dargs` as keyword arguments, and returns the result.
  """
  def _InvokeRunMethod(self):
    run_method = getattr(self, run_method_name)
    return run_method(**run_dargs)

  return _InvokeRunMethod
1205
1206
def AddParametricTests(tested_method_name, arg_space, validate_func=None):
  """Enumerates and adds specific parametric tests to PayloadCheckerTest.

  Every combination in the Cartesian product of the values in arg_space
  becomes a keyword-argument dict. Unless validate_func rejects it, the dict
  is bound to a new PayloadCheckerTest method that calls
  Do<tested_method_name>Test(**dargs). Generating one method per combination
  (instead of looping inside a single test) lets unittest treat each as a
  complete run with the usual setUp/tearDown mechanics.

  The generated method name is 'test<tested_method_name>' followed by
  '__<arg>=<value>' for every argument whose value is truthy or exactly of
  type int.

  Args:
    tested_method_name: Name of the tested PayloadChecker method.
    arg_space: A dictionary mapping each argument name to the sequence of
               values it may take.
    validate_func: Optional predicate called with a combination as keyword
                   arguments; combinations it rejects are skipped.
  """
  run_method_name = 'Do%sTest' % tested_method_name
  name_prefix = 'test%s' % tested_method_name
  arg_names = list(arg_space.iterkeys())

  for combo in itertools.product(*arg_space.itervalues()):
    run_dargs = dict(zip(arg_names, combo))
    if validate_func and not validate_func(**run_dargs):
      continue
    name_parts = [name_prefix]
    name_parts.extend('__%s=%s' % (arg_name, arg_value)
                      for arg_name, arg_value in run_dargs.iteritems()
                      if arg_value or type(arg_value) is int)
    setattr(PayloadCheckerTest, ''.join(name_parts),
            TestMethodBody(run_method_name, run_dargs))
1234
1235
def AddAllParametricTests():
  """Enumerates and adds all parametric tests to PayloadCheckerTest.

  Each AddParametricTests() call below names a Do<Name>Test runner method and
  the value space of its arguments. One test method is generated per argument
  combination, filtered by validate_func where one is given.
  """
  # Add all _CheckElem() test cases.
  AddParametricTests('AddElem',
                     {'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False),
                      'is_mandatory': (True, False),
                      'is_submsg': (True, False)})

  # Add all _Add{Mandatory,Optional}Field tests.
  AddParametricTests('AddField',
                     {'is_mandatory': (True, False),
                      'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False)})

  # Add all _Add{Mandatory,Optional}SubMsg tests.
  AddParametricTests('AddSubMsg',
                     {'is_mandatory': (True, False),
                      'is_present': (True, False)})

  # Add all _CheckManifest() test cases.
  AddParametricTests('CheckManifest',
                     {'fail_mismatched_block_size': (True, False),
                      'fail_bad_sigs': (True, False),
                      'fail_mismatched_oki_ori': (True, False),
                      'fail_bad_oki': (True, False),
                      'fail_bad_ori': (True, False),
                      'fail_bad_nki': (True, False),
                      'fail_bad_nri': (True, False),
                      'fail_old_kernel_fs_size': (True, False),
                      'fail_old_rootfs_fs_size': (True, False),
                      'fail_new_kernel_fs_size': (True, False),
                      'fail_new_rootfs_fs_size': (True, False)})

  # Add all _CheckOperation() test cases.
  AddParametricTests('CheckOperation',
                     {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'MOVE',
                                       'BSDIFF', 'SOURCE_COPY',
                                       'SOURCE_BSDIFF'),
                      'is_last': (True, False),
                      'allow_signature': (True, False),
                      'allow_unhashed': (True, False),
                      'fail_src_extents': (True, False),
                      'fail_dst_extents': (True, False),
                      'fail_mismatched_data_offset_length': (True, False),
                      'fail_missing_dst_extents': (True, False),
                      'fail_src_length': (True, False),
                      'fail_dst_length': (True, False),
                      'fail_data_hash': (True, False),
                      'fail_prev_data_offset': (True, False),
                      'fail_bad_minor_version': (True, False)},
                     validate_func=ValidateCheckOperationTest)

  # Add all _CheckOperations() test cases.
  AddParametricTests('CheckOperations',
                     {'fail_nonexhaustive_full_update': (True, False)})

  # Add all _CheckSignatures() test cases.
  AddParametricTests('CheckSignatures',
                     {'fail_empty_sigs_blob': (True, False),
                      'fail_missing_pseudo_op': (True, False),
                      'fail_mismatched_pseudo_op': (True, False),
                      'fail_sig_missing_fields': (True, False),
                      'fail_unknown_sig_version': (True, False),
                      'fail_incorrect_sig': (True, False)})

  # Add all _CheckManifestMinorVersion() test cases.
  AddParametricTests('CheckManifestMinorVersion',
                     {'minor_version': (None, 0, 1, 2, 3, 4, 555),
                      'payload_type': (checker._TYPE_FULL,
                                       checker._TYPE_DELTA)})

  # Add all Run() test cases.
  AddParametricTests('Run',
                     {'rootfs_part_size_provided': (True, False),
                      'kernel_part_size_provided': (True, False),
                      'fail_wrong_payload_type': (True, False),
                      'fail_invalid_block_size': (True, False),
                      'fail_mismatched_block_size': (True, False),
                      'fail_excess_data': (True, False),
                      'fail_rootfs_part_size_exceeded': (True, False),
                      'fail_kernel_part_size_exceeded': (True, False)})
1322
1323
if __name__ == '__main__':
  # Attach the generated parametric test methods to PayloadCheckerTest before
  # handing control to the unittest runner.
  # NOTE(review): registration happens on this script path; if the module is
  # loaded by a different test runner, the parametric tests are presumably
  # not registered. Confirm against how this suite is invoked.
  AddAllParametricTests()
  unittest.main()
1327