1import os 2import unittest 3import collections 4import email 5from email.message import Message 6from email._policybase import compat32 7from test.support import load_package_tests 8from test.test_email import __file__ as landmark 9 10# Load all tests in package 11def load_tests(*args): 12 return load_package_tests(os.path.dirname(__file__), *args) 13 14 15# helper code used by a number of test modules. 16 17def openfile(filename, *args, **kws): 18 path = os.path.join(os.path.dirname(landmark), 'data', filename) 19 return open(path, *args, **kws) 20 21 22# Base test class 23class TestEmailBase(unittest.TestCase): 24 25 maxDiff = None 26 # Currently the default policy is compat32. By setting that as the default 27 # here we make minimal changes in the test_email tests compared to their 28 # pre-3.3 state. 29 policy = compat32 30 # Likewise, the default message object is Message. 31 message = Message 32 33 def __init__(self, *args, **kw): 34 super().__init__(*args, **kw) 35 self.addTypeEqualityFunc(bytes, self.assertBytesEqual) 36 37 # Backward compatibility to minimize test_email test changes. 38 ndiffAssertEqual = unittest.TestCase.assertEqual 39 40 def _msgobj(self, filename): 41 with openfile(filename, encoding="utf-8") as fp: 42 return email.message_from_file(fp, policy=self.policy) 43 44 def _str_msg(self, string, message=None, policy=None): 45 if policy is None: 46 policy = self.policy 47 if message is None: 48 message = self.message 49 return email.message_from_string(string, message, policy=policy) 50 51 def _bytes_msg(self, bytestring, message=None, policy=None): 52 if policy is None: 53 policy = self.policy 54 if message is None: 55 message = self.message 56 return email.message_from_bytes(bytestring, message, policy=policy) 57 58 def _make_message(self): 59 return self.message(policy=self.policy) 60 61 def _bytes_repr(self, b): 62 return [repr(x) for x in b.splitlines(keepends=True)] 63 64 def assertBytesEqual(self, first, second, msg): 65 """Our byte strings are really encoded strings; improve diff output""" 66 self.assertEqual(self._bytes_repr(first), self._bytes_repr(second)) 67 68 def assertDefectsEqual(self, actual, expected): 69 self.assertEqual(len(actual), len(expected), actual) 70 for i in range(len(actual)): 71 self.assertIsInstance(actual[i], expected[i], 72 'item {}'.format(i)) 73 74 75def parameterize(cls): 76 """A test method parameterization class decorator. 77 78 Parameters are specified as the value of a class attribute that ends with 79 the string '_params'. Call the portion before '_params' the prefix. Then 80 a method to be parameterized must have the same prefix, the string 81 '_as_', and an arbitrary suffix. 82 83 The value of the _params attribute may be either a dictionary or a list. 84 The values in the dictionary and the elements of the list may either be 85 single values, or a list. If single values, they are turned into single 86 element tuples. However derived, the resulting sequence is passed via 87 *args to the parameterized test function. 88 89 In a _params dictionary, the keys become part of the name of the generated 90 tests. In a _params list, the values in the list are converted into a 91 string by joining the string values of the elements of the tuple by '_' and 92 converting any blanks into '_'s, and this become part of the name. 93 The full name of a generated test is a 'test_' prefix, the portion of the 94 test function name after the '_as_' separator, plus an '_', plus the name 95 derived as explained above. 96 97 For example, if we have: 98 99 count_params = range(2) 100 101 def count_as_foo_arg(self, foo): 102 self.assertEqual(foo+1, myfunc(foo)) 103 104 we will get parameterized test methods named: 105 test_foo_arg_0 106 test_foo_arg_1 107 test_foo_arg_2 108 109 Or we could have: 110 111 example_params = {'foo': ('bar', 1), 'bing': ('bang', 2)} 112 113 def example_as_myfunc_input(self, name, count): 114 self.assertEqual(name+str(count), myfunc(name, count)) 115 116 and get: 117 test_myfunc_input_foo 118 test_myfunc_input_bing 119 120 Note: if and only if the generated test name is a valid identifier can it 121 be used to select the test individually from the unittest command line. 122 123 The values in the params dict can be a single value, a tuple, or a 124 dict. If a single value of a tuple, it is passed to the test function 125 as positional arguments. If a dict, it is a passed via **kw. 126 127 """ 128 paramdicts = {} 129 testers = collections.defaultdict(list) 130 for name, attr in cls.__dict__.items(): 131 if name.endswith('_params'): 132 if not hasattr(attr, 'keys'): 133 d = {} 134 for x in attr: 135 if not hasattr(x, '__iter__'): 136 x = (x,) 137 n = '_'.join(str(v) for v in x).replace(' ', '_') 138 d[n] = x 139 attr = d 140 paramdicts[name[:-7] + '_as_'] = attr 141 if '_as_' in name: 142 testers[name.split('_as_')[0] + '_as_'].append(name) 143 testfuncs = {} 144 for name in paramdicts: 145 if name not in testers: 146 raise ValueError("No tester found for {}".format(name)) 147 for name in testers: 148 if name not in paramdicts: 149 raise ValueError("No params found for {}".format(name)) 150 for name, attr in cls.__dict__.items(): 151 for paramsname, paramsdict in paramdicts.items(): 152 if name.startswith(paramsname): 153 testnameroot = 'test_' + name[len(paramsname):] 154 for paramname, params in paramsdict.items(): 155 if hasattr(params, 'keys'): 156 test = (lambda self, name=name, params=params: 157 getattr(self, name)(**params)) 158 else: 159 test = (lambda self, name=name, params=params: 160 getattr(self, name)(*params)) 161 testname = testnameroot + '_' + paramname 162 test.__name__ = testname 163 testfuncs[testname] = test 164 for key, value in testfuncs.items(): 165 setattr(cls, key, value) 166 return cls 167