• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1__author__ = "raphtee@google.com (Travis Miller)"
2
3
4import re, collections, StringIO, sys, unittest
5
6
class StubNotFoundError(Exception):
    """Raised when god is asked to unstub an attribute that was not stubbed."""
10
11
class CheckPlaybackError(Exception):
    """Raised when mock playback does not match recorded calls."""
15
16
class SaveDataAfterCloseStringIO(StringIO.StringIO):
    """StringIO that keeps a copy of its contents once it has been closed.

    Intended as a stand-in for an output file in tests: after the code under
    test closes it, the test can check both that close() happened and what
    had been written.

    Properties:
      final_data: None until close() is called; afterwards, the value that
          getvalue() returned at the moment of closing.
    """
    # Class-level default so the attribute exists before close() is called.
    final_data = None

    def close(self):
        # The contents must be captured before the base class closes the
        # buffer.
        snapshot = self.getvalue()
        self.final_data = snapshot
        StringIO.StringIO.close(self)
32
33
34
class argument_comparator(object):
    """Base class for objects that decide whether a call argument is acceptable.

    Subclasses override is_satisfied_by(); the base implementation always
    raises NotImplementedError.
    """
    def is_satisfied_by(self, parameter):
        """Return whether parameter is acceptable (subclasses must override)."""
        raise NotImplementedError()
38
39
class equality_comparator(argument_comparator):
    """Comparator satisfied by a parameter that matches a stored value.

    Matching is type-aware: apart from strings (any two basestring instances
    are considered the same type), the parameter must have exactly the same
    type as the expected value.  Lists, tuples and dicts are compared
    recursively, and any argument_comparator found inside the expected value
    is asked to judge the corresponding part of the parameter.
    """
    def __init__(self, value):
        self.value = value


    @staticmethod
    def _types_match(arg1, arg2):
        """Return True if both are strings or both have the exact same type."""
        both_strings = (isinstance(arg1, basestring) and
                        isinstance(arg2, basestring))
        return both_strings or type(arg1) == type(arg2)


    @classmethod
    def _compare(cls, actual_arg, expected_arg):
        """Return True if actual_arg matches expected_arg, recursing as needed."""
        # A nested comparator decides for itself.
        if isinstance(expected_arg, argument_comparator):
            return expected_arg.is_satisfied_by(actual_arg)
        if not cls._types_match(expected_arg, actual_arg):
            return False

        if isinstance(expected_arg, (list, tuple)):
            # Sequences must agree in length and element by element.
            if len(actual_arg) != len(expected_arg):
                return False
            for got, wanted in zip(actual_arg, expected_arg):
                if not cls._compare(got, wanted):
                    return False
            return True

        if isinstance(expected_arg, dict):
            # Dicts must have matching key sets, then matching values per key.
            actual_keys = sorted(actual_arg.keys())
            expected_keys = sorted(expected_arg.keys())
            if not cls._compare(actual_keys, expected_keys):
                return False
            for key, got in actual_arg.iteritems():
                if not cls._compare(got, expected_arg[key]):
                    return False
            return True

        # Everything else falls back to the objects' own inequality test.
        return not (actual_arg != expected_arg)


    def is_satisfied_by(self, parameter):
        """Return whether parameter matches the value given at construction."""
        return self._compare(parameter, self.value)


    def __str__(self):
        # Wrapped comparators describe themselves; plain values use repr().
        if not isinstance(self.value, argument_comparator):
            return repr(self.value)
        return str(self.value)
88
89
class regex_comparator(argument_comparator):
    """Comparator satisfied by a string in which a regular expression is found.

    The pattern is compiled once at construction with the given flags and
    applied with search(), so it may match anywhere in the parameter.
    """
    def __init__(self, pattern, flags=0):
        self.regex = re.compile(pattern, flags)


    def is_satisfied_by(self, parameter):
        """Return True if the pattern is found in parameter.

        A parameter that the regex engine cannot search (for example None or
        an int) is treated as a mismatch rather than an error.
        """
        try:
            return self.regex.search(parameter) is not None
        except TypeError:
            # search() rejects non-string input with TypeError; report it as
            # "does not match" so playback can describe the incorrect call.
            return False


    def __str__(self):
        return self.regex.pattern
101
102
class is_string_comparator(argument_comparator):
    """Comparator satisfied by any basestring instance."""
    def is_satisfied_by(self, parameter):
        """Return True if parameter is a basestring instance."""
        is_string = isinstance(parameter, basestring)
        return is_string


    def __str__(self):
        return "a string"
110
111
class is_instance_comparator(argument_comparator):
    """Comparator satisfied by any instance of a given class.

    cls is passed straight to isinstance(), so it may be a single class or a
    tuple of classes.
    """
    def __init__(self, cls):
        self.cls = cls


    def is_satisfied_by(self, parameter):
        """Return True if parameter is an instance of the configured class."""
        return isinstance(parameter, self.cls)


    def __str__(self):
        # Wrap in a 1-tuple: if self.cls is itself a tuple of classes, a bare
        # "%s" % self.cls would treat it as the argument list and raise
        # TypeError.
        return "is a %s" % (self.cls,)
123
124
class anything_comparator(argument_comparator):
    """Comparator satisfied by every parameter."""
    def is_satisfied_by(self, parameter):
        # No inspection of the parameter: everything is acceptable.
        return True


    def __str__(self):
        return 'anything'
132
133
class base_mapping(object):
    """One expected call: a symbol, its expected arguments and its outcome.

    Every expected positional and keyword argument is wrapped in an
    equality_comparator.  return_obj is the stored result for the call, and
    error starts as None.
    """
    def __init__(self, symbol, return_obj, *args, **dargs):
        self.return_obj = return_obj
        self.symbol = symbol
        self.args = [equality_comparator(expected) for expected in args]

        wrapped_dargs = {}
        for name, expected in dargs.iteritems():
            wrapped_dargs[name] = equality_comparator(expected)
        self.dargs = wrapped_dargs

        self.error = None


    def match(self, *args, **dargs):
        """Return True if the given call arguments satisfy this mapping."""
        # The argument counts must agree before anything is compared.
        if len(args) != len(self.args):
            return False
        if len(dargs) != len(self.dargs):
            return False

        # Positional arguments are checked pairwise, in order.
        for comparator, actual in zip(self.args, args):
            if not comparator.is_satisfied_by(actual):
                return False

        # Every keyword argument passed must be expected and must match.
        for name, actual in dargs.iteritems():
            try:
                comparator = self.dargs[name]
            except KeyError:
                return False
            if not comparator.is_satisfied_by(actual):
                return False

        # Every expected keyword argument must have been passed.
        for name in self.dargs:
            if name not in dargs:
                return False

        return True


    def __str__(self):
        return _dump_function_call(self.symbol, self.args, self.dargs)
169
170
class function_mapping(base_mapping):
    """Expected function call whose outcome can be configured after creation."""
    def __init__(self, symbol, return_val, *args, **dargs):
        parent = super(function_mapping, self)
        parent.__init__(symbol, return_val, *args, **dargs)


    def and_return(self, return_obj):
        """Set the object stored as this call's return value."""
        self.return_obj = return_obj


    def and_raises(self, error):
        """Set the exception stored as this call's error."""
        self.error = error
183
184
class function_any_args_mapping(function_mapping):
    """Expected function call that accepts whatever arguments it is given."""
    def match(self, *args, **dargs):
        # Arguments are deliberately not inspected.
        return True
189
190
class mock_function(object):
    """Callable stand-in that logs its calls and can defer to a playback hook.

    Each call increments num_calls and appends the positional and keyword
    arguments to the args and dargs lists.  If a playback callable was given,
    the call is forwarded to it together with the symbol name and its result
    is returned; otherwise default_return_val is returned.  If a record
    callable was given, expectations created here are passed to it.
    """
    def __init__(self, symbol, default_return_val=None,
                 record=None, playback=None):
        self.symbol = symbol
        self.__name__ = symbol
        self.default_return_val = default_return_val
        self.record = record
        self.playback = playback
        # Call history.
        self.num_calls = 0
        self.args = []
        self.dargs = []


    def __call__(self, *args, **dargs):
        self.num_calls += 1
        self.args.append(args)
        self.dargs.append(dargs)
        if not self.playback:
            return self.default_return_val
        return self.playback(self.symbol, *args, **dargs)


    def _record_mapping(self, mapping):
        """Hand mapping to the record hook, if there is one, and return it."""
        if self.record:
            self.record(mapping)

        return mapping


    def expect_call(self, *args, **dargs):
        """Create, record and return a mapping expecting these arguments."""
        expectation = function_mapping(self.symbol, None, *args, **dargs)
        return self._record_mapping(expectation)


    def expect_any_call(self):
        """Like expect_call but don't give a hoot what arguments are passed."""
        expectation = function_any_args_mapping(self.symbol, None)
        return self._record_mapping(expectation)
229
230
class mask_function(mock_function):
    """mock_function that also keeps the callable it is standing in for."""
    def __init__(self, symbol, original_function, default_return_val=None,
                 record=None, playback=None):
        parent = super(mask_function, self)
        parent.__init__(symbol, default_return_val, record, playback)
        self.original_function = original_function


    def run_original_function(self, *args, **dargs):
        """Call the saved original callable and return its result."""
        original = self.original_function
        return original(*args, **dargs)
242
243
class mock_class(object):
    """Object exposing mocked versions of the public attributes of cls.

    Every attribute of cls whose name does not start with an underscore is
    mirrored onto the instance: callables are replaced by mock_function
    objects named "<name>.<symbol>", and everything else is copied as is.
    """
    def __init__(self, cls, name, default_ret_val=None,
                 record=None, playback=None):
        self.__name = name
        self.__record = record
        self.__playback = playback

        public_symbols = [symbol for symbol in dir(cls)
                          if not symbol.startswith("_")]
        for symbol in public_symbols:
            attribute = getattr(cls, symbol)
            if not callable(attribute):
                # Plain data is shared with the original unchanged.
                setattr(self, symbol, attribute)
                continue

            qualified_name = "%s.%s" % (self.__name, symbol)
            replacement = mock_function(qualified_name, default_ret_val,
                                        self.__record, self.__playback)
            setattr(self, symbol, replacement)


    def __repr__(self):
        return '<mock_class: %s>' % self.__name
267
268
class mock_god(object):
    """Central recorder and verifier for mock calls and attribute stubs.

    Expectations created through the mocks this object hands out are queued
    on self.recording and consumed in order as those mocks are called.
    Mismatches either raise CheckPlaybackError at once or are collected in
    self.errors, depending on fail_fast.  The god also remembers every
    attribute replaced through stub_with() so that it can be restored.
    """
    # Sentinel stored as the "original" of an attribute that must be deleted
    # (rather than reassigned) when the stub is removed.
    NONEXISTENT_ATTRIBUTE = object()

    def __init__(self, debug=False, fail_fast=True, ut=None):
        """
        With debug=True, all recorded method calls will be printed as
        they happen.
        With fail_fast=True, unexpected calls will immediately cause an
        exception to be raised.  With False, they will be silently recorded and
        only reported when check_playback() is called.
        If ut is given, check_playback() calls ut.fail() with the collected
        error text before raising.
        """
        self.recording = collections.deque()
        self.errors = []
        self._stubs = []
        self._debug = debug
        self._fail_fast = fail_fast
        self._ut = ut


    def set_fail_fast(self, fail_fast):
        """Change whether playback errors are raised immediately."""
        self._fail_fast = fail_fast


    def create_mock_class_obj(self, cls, name, default_ret_val=None):
        """
        Return a subclass of cls whose construction goes through playback.

        Instantiating the returned class does not build a new object; it
        returns whatever playback yields for a call to name with the same
        arguments.  Use its expect_new() classmethod to build a mocked
        instance and record that such a construction is expected.
        """
        record = self.__record_call
        playback = self.__method_playback

        class cls_sub(cls):
            cls_count = 0

            # overwrite the initializer
            def __init__(self, *args, **dargs):
                pass


            @classmethod
            def expect_new(typ, *args, **dargs):
                # Build the instance now and record that constructing the
                # class with these arguments should return it.
                obj = typ.make_new(*args, **dargs)
                mapping = base_mapping(name, obj, *args, **dargs)
                record(mapping)
                return obj


            def __new__(typ, *args, **dargs):
                return playback(name, *args, **dargs)


            @classmethod
            def make_new(typ, *args, **dargs):
                obj = super(cls_sub, typ).__new__(typ, *args,
                                                  **dargs)

                # Number the instances so their mocked methods get
                # distinct names.
                typ.cls_count += 1
                obj_name = "%s_%s" % (name, typ.cls_count)
                for symbol in dir(obj):
                    if (symbol.startswith("__") and
                        symbol.endswith("__")):
                        continue

                    # Properties are left alone.
                    if isinstance(getattr(typ, symbol, None), property):
                        continue

                    orig_symbol = getattr(obj, symbol)
                    if callable(orig_symbol):
                        f_name = ("%s.%s" %
                                  (obj_name, symbol))
                        func = mock_function(f_name,
                                        default_ret_val,
                                        record,
                                        playback)
                        setattr(obj, symbol, func)
                    else:
                        setattr(obj, symbol,
                                orig_symbol)

                return obj

        return cls_sub


    def create_mock_class(self, cls, name, default_ret_val=None):
        """
        Given something that defines a namespace cls (class, object,
        module), and a (hopefully unique) name, will create a
        mock_class object with that name and that possesses all
        the public attributes of cls.  default_ret_val sets the
        default_ret_val on all methods of the cls mock.
        """
        return mock_class(cls, name, default_ret_val,
                          self.__record_call, self.__method_playback)


    def create_mock_function(self, symbol, default_return_val=None):
        """
        create a mock_function with name symbol and default return
        value of default_return_val.
        """
        return mock_function(symbol, default_return_val,
                             self.__record_call, self.__method_playback)


    def mock_up(self, obj, name, default_ret_val=None):
        """
        Given an object (class instance or module) and a registration
        name, then replace all its methods with mock function objects
        (passing the original functions to the mock functions).
        """
        for symbol in dir(obj):
            if symbol.startswith("__"):
                continue

            orig_symbol = getattr(obj, symbol)
            if callable(orig_symbol):
                f_name = "%s.%s" % (name, symbol)
                func = mask_function(f_name, orig_symbol,
                                     default_ret_val,
                                     self.__record_call,
                                     self.__method_playback)
                setattr(obj, symbol, func)


    def stub_with(self, namespace, symbol, new_attribute):
        """
        Replace namespace.symbol with new_attribute and remember how to
        undo the replacement.
        """
        original_attribute = getattr(namespace, symbol,
                                     self.NONEXISTENT_ATTRIBUTE)

        # You only want to save the original attribute in cases where it is
        # directly associated with the object in question. In cases where
        # the attribute is actually inherited via some sort of hierarchy
        # you want to delete the stub (restoring the original structure)
        attribute_is_inherited = (hasattr(namespace, '__dict__') and
                                  symbol not in namespace.__dict__)
        if attribute_is_inherited:
            original_attribute = self.NONEXISTENT_ATTRIBUTE

        newstub = (namespace, symbol, original_attribute, new_attribute)
        self._stubs.append(newstub)
        setattr(namespace, symbol, new_attribute)


    def stub_function(self, namespace, symbol):
        """Replace namespace.symbol with a new mock function named symbol."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(namespace, symbol, mock_attribute)


    def stub_class_method(self, cls, symbol):
        """Replace cls.symbol with a mock function wrapped in staticmethod."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(cls, symbol, staticmethod(mock_attribute))


    def stub_class(self, namespace, symbol):
        """Replace the class namespace.symbol with a mock class object."""
        attr = getattr(namespace, symbol)
        mock_class = self.create_mock_class_obj(attr, symbol)
        self.stub_with(namespace, symbol, mock_class)


    def stub_function_to_return(self, namespace, symbol, object_to_return):
        """Stub out a function with one that always returns a fixed value.

        @param namespace The namespace containing the function to stub out.
        @param symbol The attribute within the namespace to stub out.
        @param object_to_return The value that the stub should return whenever
            it is called.
        """
        self.stub_with(namespace, symbol,
                       lambda *args, **dargs: object_to_return)


    def _perform_unstub(self, stub):
        """Undo one stub tuple as stored by stub_with()."""
        namespace, symbol, orig_attr, new_attr = stub
        # Identity test: the sentinel must never be confused with an original
        # attribute whose __eq__ happens to report equality.
        if orig_attr is self.NONEXISTENT_ATTRIBUTE:
            delattr(namespace, symbol)
        else:
            setattr(namespace, symbol, orig_attr)


    def unstub(self, namespace, symbol):
        """
        Undo the most recent stub of namespace.symbol.

        Raises StubNotFoundError if that attribute is not currently stubbed.
        """
        for index in range(len(self._stubs) - 1, -1, -1):
            stub = self._stubs[index]
            if (namespace, symbol) == (stub[0], stub[1]):
                self._perform_unstub(stub)
                # Delete by position so that exactly the entry just undone is
                # dropped, without comparing the stored attributes.
                del self._stubs[index]
                return

        raise StubNotFoundError()


    def unstub_all(self):
        """Undo every remaining stub, most recent first."""
        # Iterate a reversed view instead of reversing in place, so the list
        # keeps its order if restoring one of the attributes raises.
        for stub in reversed(self._stubs):
            self._perform_unstub(stub)
        self._stubs = []


    def __method_playback(self, symbol, *args, **dargs):
        """
        Check a mock call against the next recorded expectation.

        On a match the expectation is consumed and its error is raised or
        its return object returned.  Otherwise an error is reported through
        _append_error() and None is returned.
        """
        if self._debug:
            print >> sys.__stdout__, (' * Mock call: ' +
                                      _dump_function_call(symbol, args, dargs))

        if len(self.recording) != 0:
            func_call = self.recording[0]
            if func_call.symbol != symbol:
                msg = ("Unexpected call: %s\nExpected: %s"
                    % (_dump_function_call(symbol, args, dargs),
                       func_call))
                self._append_error(msg)
                return None

            if not func_call.match(*args, **dargs):
                msg = ("Incorrect call: %s\nExpected: %s"
                    % (_dump_function_call(symbol, args, dargs),
                      func_call))
                self._append_error(msg)
                return None

            # this is the expected call so pop it and return
            self.recording.popleft()
            if func_call.error:
                raise func_call.error
            else:
                return func_call.return_obj
        else:
            msg = ("unexpected call: %s"
                   % (_dump_function_call(symbol, args, dargs)))
            self._append_error(msg)
            return None


    def __record_call(self, mapping):
        """Queue mapping as the next expected call."""
        self.recording.append(mapping)


    def _append_error(self, error):
        """Raise error if fail_fast is set, otherwise store it for later."""
        if self._debug:
            print >> sys.__stdout__, ' *** ' + error
        if self._fail_fast:
            raise CheckPlaybackError(error)
        self.errors.append(error)


    def check_playback(self):
        """
        Report any errors that were encountered during calls
        to __method_playback(), or any expected calls that never happened.

        Raises CheckPlaybackError, carrying the error text, in either case.
        If neither applies, the recording is cleared.
        """
        if len(self.errors) > 0:
            if self._debug:
                # Goes to the real stdout like every other diagnostic here,
                # so it is visible even while mock_io() is active.
                print >> sys.__stdout__, '\nPlayback errors:'
            for error in self.errors:
                print >> sys.__stdout__, error

            report = '\n'.join(self.errors)
            if self._ut:
                self._ut.fail(report)

            raise CheckPlaybackError(report)
        elif len(self.recording) != 0:
            errors = []
            for func_call in self.recording:
                error = "%s not called" % (func_call,)
                errors.append(error)
                print >> sys.__stdout__, error

            report = '\n'.join(errors)
            if self._ut:
                self._ut.fail(report)

            raise CheckPlaybackError(report)
        self.recording.clear()


    def mock_io(self):
        """Mocks and saves the stdout & stderr output"""
        self.orig_stdout = sys.stdout
        self.orig_stderr = sys.stderr

        self.mock_streams_stdout = StringIO.StringIO('')
        self.mock_streams_stderr = StringIO.StringIO('')

        sys.stdout = self.mock_streams_stdout
        sys.stderr = self.mock_streams_stderr


    def unmock_io(self):
        """Restores the stdout & stderr, and returns both
        output strings"""
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
        values = (self.mock_streams_stdout.getvalue(),
                  self.mock_streams_stderr.getvalue())

        self.mock_streams_stdout.close()
        self.mock_streams_stderr.close()
        return values
560
561
def _arg_to_str(arg):
    """Return display text for arg: str() for comparators, repr() otherwise."""
    if not isinstance(arg, argument_comparator):
        return repr(arg)
    return str(arg)
566
567
def _dump_function_call(symbol, args, dargs):
    """Format a call as "symbol(arg, ..., key=value, ...)" for messages."""
    rendered = [_arg_to_str(positional) for positional in args]
    rendered.extend(["%s=%s" % (name, _arg_to_str(value))
                     for name, value in dargs.iteritems()])
    return "%s(%s)" % (symbol, ', '.join(rendered))
575