1#!/usr/bin/python2.4 2# 3# Copyright 2008 Google Inc. 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17# This file is used for testing. The original is at: 18# http://code.google.com/p/pymox/ 19 20"""Mox, an object-mocking framework for Python. 21 22Mox works in the record-replay-verify paradigm. When you first create 23a mock object, it is in record mode. You then programmatically set 24the expected behavior of the mock object (what methods are to be 25called on it, with what parameters, what they should return, and in 26what order). 27 28Once you have set up the expected mock behavior, you put it in replay 29mode. Now the mock responds to method calls just as you told it to. 30If an unexpected method (or an expected method with unexpected 31parameters) is called, then an exception will be raised. 32 33Once you are done interacting with the mock, you need to verify that 34all the expected interactions occurred. (Maybe your code exited 35prematurely without calling some cleanup method!) The verify phase 36ensures that every expected method was called; otherwise, an exception 37will be raised. 38 39Suggested usage / workflow: 40 41 # Create Mox factory 42 my_mox = Mox() 43 44 # Create a mock data access object 45 mock_dao = my_mox.CreateMock(DAOClass) 46 47 # Set up expected behavior 48 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) 49 mock_dao.DeletePerson(person) 50 51 # Put mocks in replay mode 52 my_mox.ReplayAll() 53 54 # Inject mock object and run test 55 controller.SetDao(mock_dao) 56 controller.DeletePersonById('1') 57 58 # Verify all methods were called as expected 59 my_mox.VerifyAll() 60""" 61 62from collections import deque 63import re 64import types 65import unittest 66 67import stubout 68 69class Error(AssertionError): 70 """Base exception for this module.""" 71 72 pass 73 74 75class ExpectedMethodCallsError(Error): 76 """Raised when Verify() is called before all expected methods have been called 77 """ 78 79 def __init__(self, expected_methods): 80 """Init exception. 81 82 Args: 83 # expected_methods: A sequence of MockMethod objects that should have been 84 # called. 85 expected_methods: [MockMethod] 86 87 Raises: 88 ValueError: if expected_methods contains no methods. 89 """ 90 91 if not expected_methods: 92 raise ValueError("There must be at least one expected method") 93 Error.__init__(self) 94 self._expected_methods = expected_methods 95 96 def __str__(self): 97 calls = "\n".join(["%3d. %s" % (i, m) 98 for i, m in enumerate(self._expected_methods)]) 99 return "Verify: Expected methods never called:\n%s" % (calls,) 100 101 102class UnexpectedMethodCallError(Error): 103 """Raised when an unexpected method is called. 104 105 This can occur if a method is called with incorrect parameters, or out of the 106 specified order. 107 """ 108 109 def __init__(self, unexpected_method, expected): 110 """Init exception. 111 112 Args: 113 # unexpected_method: MockMethod that was called but was not at the head of 114 # the expected_method queue. 115 # expected: MockMethod or UnorderedGroup the method should have 116 # been in. 117 unexpected_method: MockMethod 118 expected: MockMethod or UnorderedGroup 119 """ 120 121 Error.__init__(self) 122 self._unexpected_method = unexpected_method 123 self._expected = expected 124 125 def __str__(self): 126 return "Unexpected method call: %s. Expecting: %s" % \ 127 (self._unexpected_method, self._expected) 128 129 130class UnknownMethodCallError(Error): 131 """Raised if an unknown method is requested of the mock object.""" 132 133 def __init__(self, unknown_method_name): 134 """Init exception. 135 136 Args: 137 # unknown_method_name: Method call that is not part of the mocked class's 138 # public interface. 139 unknown_method_name: str 140 """ 141 142 Error.__init__(self) 143 self._unknown_method_name = unknown_method_name 144 145 def __str__(self): 146 return "Method called is not a member of the object: %s" % \ 147 self._unknown_method_name 148 149 150class Mox(object): 151 """Mox: a factory for creating mock objects.""" 152 153 # A list of types that should be stubbed out with MockObjects (as 154 # opposed to MockAnythings). 155 _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, 156 types.ObjectType, types.TypeType] 157 158 def __init__(self): 159 """Initialize a new Mox.""" 160 161 self._mock_objects = [] 162 self.stubs = stubout.StubOutForTesting() 163 164 def CreateMock(self, class_to_mock): 165 """Create a new mock object. 166 167 Args: 168 # class_to_mock: the class to be mocked 169 class_to_mock: class 170 171 Returns: 172 MockObject that can be used as the class_to_mock would be. 173 """ 174 175 new_mock = MockObject(class_to_mock) 176 self._mock_objects.append(new_mock) 177 return new_mock 178 179 def CreateMockAnything(self): 180 """Create a mock that will accept any method calls. 181 182 This does not enforce an interface. 183 """ 184 185 new_mock = MockAnything() 186 self._mock_objects.append(new_mock) 187 return new_mock 188 189 def ReplayAll(self): 190 """Set all mock objects to replay mode.""" 191 192 for mock_obj in self._mock_objects: 193 mock_obj._Replay() 194 195 196 def VerifyAll(self): 197 """Call verify on all mock objects created.""" 198 199 for mock_obj in self._mock_objects: 200 mock_obj._Verify() 201 202 def ResetAll(self): 203 """Call reset on all mock objects. This does not unset stubs.""" 204 205 for mock_obj in self._mock_objects: 206 mock_obj._Reset() 207 208 def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): 209 """Replace a method, attribute, etc. with a Mock. 210 211 This will replace a class or module with a MockObject, and everything else 212 (method, function, etc) with a MockAnything. This can be overridden to 213 always use a MockAnything by setting use_mock_anything to True. 214 215 Args: 216 obj: A Python object (class, module, instance, callable). 217 attr_name: str. The name of the attribute to replace with a mock. 218 use_mock_anything: bool. True if a MockAnything should be used regardless 219 of the type of attribute. 220 """ 221 222 attr_to_replace = getattr(obj, attr_name) 223 if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: 224 stub = self.CreateMock(attr_to_replace) 225 else: 226 stub = self.CreateMockAnything() 227 228 self.stubs.Set(obj, attr_name, stub) 229 230 def UnsetStubs(self): 231 """Restore stubs to their original state.""" 232 233 self.stubs.UnsetAll() 234 235def Replay(*args): 236 """Put mocks into Replay mode. 237 238 Args: 239 # args is any number of mocks to put into replay mode. 240 """ 241 242 for mock in args: 243 mock._Replay() 244 245 246def Verify(*args): 247 """Verify mocks. 248 249 Args: 250 # args is any number of mocks to be verified. 251 """ 252 253 for mock in args: 254 mock._Verify() 255 256 257def Reset(*args): 258 """Reset mocks. 259 260 Args: 261 # args is any number of mocks to be reset. 262 """ 263 264 for mock in args: 265 mock._Reset() 266 267 268class MockAnything: 269 """A mock that can be used to mock anything. 270 271 This is helpful for mocking classes that do not provide a public interface. 272 """ 273 274 def __init__(self): 275 """ """ 276 self._Reset() 277 278 def __getattr__(self, method_name): 279 """Intercept method calls on this object. 280 281 A new MockMethod is returned that is aware of the MockAnything's 282 state (record or replay). The call will be recorded or replayed 283 by the MockMethod's __call__. 284 285 Args: 286 # method name: the name of the method being called. 287 method_name: str 288 289 Returns: 290 A new MockMethod aware of MockAnything's state (record or replay). 291 """ 292 293 return self._CreateMockMethod(method_name) 294 295 def _CreateMockMethod(self, method_name): 296 """Create a new mock method call and return it. 297 298 Args: 299 # method name: the name of the method being called. 300 method_name: str 301 302 Returns: 303 A new MockMethod aware of MockAnything's state (record or replay). 304 """ 305 306 return MockMethod(method_name, self._expected_calls_queue, 307 self._replay_mode) 308 309 def __nonzero__(self): 310 """Return 1 for nonzero so the mock can be used as a conditional.""" 311 312 return 1 313 314 def __eq__(self, rhs): 315 """Provide custom logic to compare objects.""" 316 317 return (isinstance(rhs, MockAnything) and 318 self._replay_mode == rhs._replay_mode and 319 self._expected_calls_queue == rhs._expected_calls_queue) 320 321 def __ne__(self, rhs): 322 """Provide custom logic to compare objects.""" 323 324 return not self == rhs 325 326 def _Replay(self): 327 """Start replaying expected method calls.""" 328 329 self._replay_mode = True 330 331 def _Verify(self): 332 """Verify that all of the expected calls have been made. 333 334 Raises: 335 ExpectedMethodCallsError: if there are still more method calls in the 336 expected queue. 337 """ 338 339 # If the list of expected calls is not empty, raise an exception 340 if self._expected_calls_queue: 341 # The last MultipleTimesGroup is not popped from the queue. 342 if (len(self._expected_calls_queue) == 1 and 343 isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and 344 self._expected_calls_queue[0].IsSatisfied()): 345 pass 346 else: 347 raise ExpectedMethodCallsError(self._expected_calls_queue) 348 349 def _Reset(self): 350 """Reset the state of this mock to record mode with an empty queue.""" 351 352 # Maintain a list of method calls we are expecting 353 self._expected_calls_queue = deque() 354 355 # Make sure we are in setup mode, not replay mode 356 self._replay_mode = False 357 358 359class MockObject(MockAnything, object): 360 """A mock object that simulates the public/protected interface of a class.""" 361 362 def __init__(self, class_to_mock): 363 """Initialize a mock object. 364 365 This determines the methods and properties of the class and stores them. 366 367 Args: 368 # class_to_mock: class to be mocked 369 class_to_mock: class 370 """ 371 372 # This is used to hack around the mixin/inheritance of MockAnything, which 373 # is not a proper object (it can be anything. :-) 374 MockAnything.__dict__['__init__'](self) 375 376 # Get a list of all the public and special methods we should mock. 377 self._known_methods = set() 378 self._known_vars = set() 379 self._class_to_mock = class_to_mock 380 for method in dir(class_to_mock): 381 if callable(getattr(class_to_mock, method)): 382 self._known_methods.add(method) 383 else: 384 self._known_vars.add(method) 385 386 def __getattr__(self, name): 387 """Intercept attribute request on this object. 388 389 If the attribute is a public class variable, it will be returned and not 390 recorded as a call. 391 392 If the attribute is not a variable, it is handled like a method 393 call. The method name is checked against the set of mockable 394 methods, and a new MockMethod is returned that is aware of the 395 MockObject's state (record or replay). The call will be recorded 396 or replayed by the MockMethod's __call__. 397 398 Args: 399 # name: the name of the attribute being requested. 400 name: str 401 402 Returns: 403 Either a class variable or a new MockMethod that is aware of the state 404 of the mock (record or replay). 405 406 Raises: 407 UnknownMethodCallError if the MockObject does not mock the requested 408 method. 409 """ 410 411 if name in self._known_vars: 412 return getattr(self._class_to_mock, name) 413 414 if name in self._known_methods: 415 return self._CreateMockMethod(name) 416 417 raise UnknownMethodCallError(name) 418 419 def __eq__(self, rhs): 420 """Provide custom logic to compare objects.""" 421 422 return (isinstance(rhs, MockObject) and 423 self._class_to_mock == rhs._class_to_mock and 424 self._replay_mode == rhs._replay_mode and 425 self._expected_calls_queue == rhs._expected_calls_queue) 426 427 def __setitem__(self, key, value): 428 """Provide custom logic for mocking classes that support item assignment. 429 430 Args: 431 key: Key to set the value for. 432 value: Value to set. 433 434 Returns: 435 Expected return value in replay mode. A MockMethod object for the 436 __setitem__ method that has already been called if not in replay mode. 437 438 Raises: 439 TypeError if the underlying class does not support item assignment. 440 UnexpectedMethodCallError if the object does not expect the call to 441 __setitem__. 442 443 """ 444 setitem = self._class_to_mock.__dict__.get('__setitem__', None) 445 446 # Verify the class supports item assignment. 447 if setitem is None: 448 raise TypeError('object does not support item assignment') 449 450 # If we are in replay mode then simply call the mock __setitem__ method. 451 if self._replay_mode: 452 return MockMethod('__setitem__', self._expected_calls_queue, 453 self._replay_mode)(key, value) 454 455 456 # Otherwise, create a mock method __setitem__. 457 return self._CreateMockMethod('__setitem__')(key, value) 458 459 def __getitem__(self, key): 460 """Provide custom logic for mocking classes that are subscriptable. 461 462 Args: 463 key: Key to return the value for. 464 465 Returns: 466 Expected return value in replay mode. A MockMethod object for the 467 __getitem__ method that has already been called if not in replay mode. 468 469 Raises: 470 TypeError if the underlying class is not subscriptable. 471 UnexpectedMethodCallError if the object does not expect the call to 472 __setitem__. 473 474 """ 475 getitem = self._class_to_mock.__dict__.get('__getitem__', None) 476 477 # Verify the class supports item assignment. 478 if getitem is None: 479 raise TypeError('unsubscriptable object') 480 481 # If we are in replay mode then simply call the mock __getitem__ method. 482 if self._replay_mode: 483 return MockMethod('__getitem__', self._expected_calls_queue, 484 self._replay_mode)(key) 485 486 487 # Otherwise, create a mock method __getitem__. 488 return self._CreateMockMethod('__getitem__')(key) 489 490 def __call__(self, *params, **named_params): 491 """Provide custom logic for mocking classes that are callable.""" 492 493 # Verify the class we are mocking is callable 494 callable = self._class_to_mock.__dict__.get('__call__', None) 495 if callable is None: 496 raise TypeError('Not callable') 497 498 # Because the call is happening directly on this object instead of a method, 499 # the call on the mock method is made right here 500 mock_method = self._CreateMockMethod('__call__') 501 return mock_method(*params, **named_params) 502 503 @property 504 def __class__(self): 505 """Return the class that is being mocked.""" 506 507 return self._class_to_mock 508 509 510class MockMethod(object): 511 """Callable mock method. 512 513 A MockMethod should act exactly like the method it mocks, accepting parameters 514 and returning a value, or throwing an exception (as specified). When this 515 method is called, it can optionally verify whether the called method (name and 516 signature) matches the expected method. 517 """ 518 519 def __init__(self, method_name, call_queue, replay_mode): 520 """Construct a new mock method. 521 522 Args: 523 # method_name: the name of the method 524 # call_queue: deque of calls, verify this call against the head, or add 525 # this call to the queue. 526 # replay_mode: False if we are recording, True if we are verifying calls 527 # against the call queue. 528 method_name: str 529 call_queue: list or deque 530 replay_mode: bool 531 """ 532 533 self._name = method_name 534 self._call_queue = call_queue 535 if not isinstance(call_queue, deque): 536 self._call_queue = deque(self._call_queue) 537 self._replay_mode = replay_mode 538 539 self._params = None 540 self._named_params = None 541 self._return_value = None 542 self._exception = None 543 self._side_effects = None 544 545 def __call__(self, *params, **named_params): 546 """Log parameters and return the specified return value. 547 548 If the Mock(Anything/Object) associated with this call is in record mode, 549 this MockMethod will be pushed onto the expected call queue. If the mock 550 is in replay mode, this will pop a MockMethod off the top of the queue and 551 verify this call is equal to the expected call. 552 553 Raises: 554 UnexpectedMethodCall if this call is supposed to match an expected method 555 call and it does not. 556 """ 557 558 self._params = params 559 self._named_params = named_params 560 561 if not self._replay_mode: 562 self._call_queue.append(self) 563 return self 564 565 expected_method = self._VerifyMethodCall() 566 567 if expected_method._side_effects: 568 expected_method._side_effects(*params, **named_params) 569 570 if expected_method._exception: 571 raise expected_method._exception 572 573 return expected_method._return_value 574 575 def __getattr__(self, name): 576 """Raise an AttributeError with a helpful message.""" 577 578 raise AttributeError('MockMethod has no attribute "%s". ' 579 'Did you remember to put your mocks in replay mode?' % name) 580 581 def _PopNextMethod(self): 582 """Pop the next method from our call queue.""" 583 try: 584 return self._call_queue.popleft() 585 except IndexError: 586 raise UnexpectedMethodCallError(self, None) 587 588 def _VerifyMethodCall(self): 589 """Verify the called method is expected. 590 591 This can be an ordered method, or part of an unordered set. 592 593 Returns: 594 The expected mock method. 595 596 Raises: 597 UnexpectedMethodCall if the method called was not expected. 598 """ 599 600 expected = self._PopNextMethod() 601 602 # Loop here, because we might have a MethodGroup followed by another 603 # group. 604 while isinstance(expected, MethodGroup): 605 expected, method = expected.MethodCalled(self) 606 if method is not None: 607 return method 608 609 # This is a mock method, so just check equality. 610 if expected != self: 611 raise UnexpectedMethodCallError(self, expected) 612 613 return expected 614 615 def __str__(self): 616 params = ', '.join( 617 [repr(p) for p in self._params or []] + 618 ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) 619 desc = "%s(%s) -> %r" % (self._name, params, self._return_value) 620 return desc 621 622 def __eq__(self, rhs): 623 """Test whether this MockMethod is equivalent to another MockMethod. 624 625 Args: 626 # rhs: the right hand side of the test 627 rhs: MockMethod 628 """ 629 630 return (isinstance(rhs, MockMethod) and 631 self._name == rhs._name and 632 self._params == rhs._params and 633 self._named_params == rhs._named_params) 634 635 def __ne__(self, rhs): 636 """Test whether this MockMethod is not equivalent to another MockMethod. 637 638 Args: 639 # rhs: the right hand side of the test 640 rhs: MockMethod 641 """ 642 643 return not self == rhs 644 645 def GetPossibleGroup(self): 646 """Returns a possible group from the end of the call queue or None if no 647 other methods are on the stack. 648 """ 649 650 # Remove this method from the tail of the queue so we can add it to a group. 651 this_method = self._call_queue.pop() 652 assert this_method == self 653 654 # Determine if the tail of the queue is a group, or just a regular ordered 655 # mock method. 656 group = None 657 try: 658 group = self._call_queue[-1] 659 except IndexError: 660 pass 661 662 return group 663 664 def _CheckAndCreateNewGroup(self, group_name, group_class): 665 """Checks if the last method (a possible group) is an instance of our 666 group_class. Adds the current method to this group or creates a new one. 667 668 Args: 669 670 group_name: the name of the group. 671 group_class: the class used to create instance of this new group 672 """ 673 group = self.GetPossibleGroup() 674 675 # If this is a group, and it is the correct group, add the method. 676 if isinstance(group, group_class) and group.group_name() == group_name: 677 group.AddMethod(self) 678 return self 679 680 # Create a new group and add the method. 681 new_group = group_class(group_name) 682 new_group.AddMethod(self) 683 self._call_queue.append(new_group) 684 return self 685 686 def InAnyOrder(self, group_name="default"): 687 """Move this method into a group of unordered calls. 688 689 A group of unordered calls must be defined together, and must be executed 690 in full before the next expected method can be called. There can be 691 multiple groups that are expected serially, if they are given 692 different group names. The same group name can be reused if there is a 693 standard method call, or a group with a different name, spliced between 694 usages. 695 696 Args: 697 group_name: the name of the unordered group. 698 699 Returns: 700 self 701 """ 702 return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) 703 704 def MultipleTimes(self, group_name="default"): 705 """Move this method into group of calls which may be called multiple times. 706 707 A group of repeating calls must be defined together, and must be executed in 708 full before the next expected mehtod can be called. 709 710 Args: 711 group_name: the name of the unordered group. 712 713 Returns: 714 self 715 """ 716 return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) 717 718 def AndReturn(self, return_value): 719 """Set the value to return when this method is called. 720 721 Args: 722 # return_value can be anything. 723 """ 724 725 self._return_value = return_value 726 return return_value 727 728 def AndRaise(self, exception): 729 """Set the exception to raise when this method is called. 730 731 Args: 732 # exception: the exception to raise when this method is called. 733 exception: Exception 734 """ 735 736 self._exception = exception 737 738 def WithSideEffects(self, side_effects): 739 """Set the side effects that are simulated when this method is called. 740 741 Args: 742 side_effects: A callable which modifies the parameters or other relevant 743 state which a given test case depends on. 744 745 Returns: 746 Self for chaining with AndReturn and AndRaise. 747 """ 748 self._side_effects = side_effects 749 return self 750 751class Comparator: 752 """Base class for all Mox comparators. 753 754 A Comparator can be used as a parameter to a mocked method when the exact 755 value is not known. For example, the code you are testing might build up a 756 long SQL string that is passed to your mock DAO. You're only interested that 757 the IN clause contains the proper primary keys, so you can set your mock 758 up as follows: 759 760 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) 761 762 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. 763 764 A Comparator may replace one or more parameters, for example: 765 # return at most 10 rows 766 mock_dao.RunQuery(StrContains('SELECT'), 10) 767 768 or 769 770 # Return some non-deterministic number of rows 771 mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) 772 """ 773 774 def equals(self, rhs): 775 """Special equals method that all comparators must implement. 776 777 Args: 778 rhs: any python object 779 """ 780 781 raise NotImplementedError('method must be implemented by a subclass.') 782 783 def __eq__(self, rhs): 784 return self.equals(rhs) 785 786 def __ne__(self, rhs): 787 return not self.equals(rhs) 788 789 790class IsA(Comparator): 791 """This class wraps a basic Python type or class. It is used to verify 792 that a parameter is of the given type or class. 793 794 Example: 795 mock_dao.Connect(IsA(DbConnectInfo)) 796 """ 797 798 def __init__(self, class_name): 799 """Initialize IsA 800 801 Args: 802 class_name: basic python type or a class 803 """ 804 805 self._class_name = class_name 806 807 def equals(self, rhs): 808 """Check to see if the RHS is an instance of class_name. 809 810 Args: 811 # rhs: the right hand side of the test 812 rhs: object 813 814 Returns: 815 bool 816 """ 817 818 try: 819 return isinstance(rhs, self._class_name) 820 except TypeError: 821 # Check raw types if there was a type error. This is helpful for 822 # things like cStringIO.StringIO. 823 return type(rhs) == type(self._class_name) 824 825 def __repr__(self): 826 return str(self._class_name) 827 828class IsAlmost(Comparator): 829 """Comparison class used to check whether a parameter is nearly equal 830 to a given value. Generally useful for floating point numbers. 831 832 Example mock_dao.SetTimeout((IsAlmost(3.9))) 833 """ 834 835 def __init__(self, float_value, places=7): 836 """Initialize IsAlmost. 837 838 Args: 839 float_value: The value for making the comparison. 840 places: The number of decimal places to round to. 841 """ 842 843 self._float_value = float_value 844 self._places = places 845 846 def equals(self, rhs): 847 """Check to see if RHS is almost equal to float_value 848 849 Args: 850 rhs: the value to compare to float_value 851 852 Returns: 853 bool 854 """ 855 856 try: 857 return round(rhs-self._float_value, self._places) == 0 858 except TypeError: 859 # This is probably because either float_value or rhs is not a number. 860 return False 861 862 def __repr__(self): 863 return str(self._float_value) 864 865class StrContains(Comparator): 866 """Comparison class used to check whether a substring exists in a 867 string parameter. This can be useful in mocking a database with SQL 868 passed in as a string parameter, for example. 869 870 Example: 871 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) 872 """ 873 874 def __init__(self, search_string): 875 """Initialize. 876 877 Args: 878 # search_string: the string you are searching for 879 search_string: str 880 """ 881 882 self._search_string = search_string 883 884 def equals(self, rhs): 885 """Check to see if the search_string is contained in the rhs string. 886 887 Args: 888 # rhs: the right hand side of the test 889 rhs: object 890 891 Returns: 892 bool 893 """ 894 895 try: 896 return rhs.find(self._search_string) > -1 897 except Exception: 898 return False 899 900 def __repr__(self): 901 return '<str containing \'%s\'>' % self._search_string 902 903 904class Regex(Comparator): 905 """Checks if a string matches a regular expression. 906 907 This uses a given regular expression to determine equality. 908 """ 909 910 def __init__(self, pattern, flags=0): 911 """Initialize. 912 913 Args: 914 # pattern is the regular expression to search for 915 pattern: str 916 # flags passed to re.compile function as the second argument 917 flags: int 918 """ 919 920 self.regex = re.compile(pattern, flags=flags) 921 922 def equals(self, rhs): 923 """Check to see if rhs matches regular expression pattern. 924 925 Returns: 926 bool 927 """ 928 929 return self.regex.search(rhs) is not None 930 931 def __repr__(self): 932 s = '<regular expression \'%s\'' % self.regex.pattern 933 if self.regex.flags: 934 s += ', flags=%d' % self.regex.flags 935 s += '>' 936 return s 937 938 939class In(Comparator): 940 """Checks whether an item (or key) is in a list (or dict) parameter. 941 942 Example: 943 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) 944 """ 945 946 def __init__(self, key): 947 """Initialize. 948 949 Args: 950 # key is any thing that could be in a list or a key in a dict 951 """ 952 953 self._key = key 954 955 def equals(self, rhs): 956 """Check to see whether key is in rhs. 957 958 Args: 959 rhs: dict 960 961 Returns: 962 bool 963 """ 964 965 return self._key in rhs 966 967 def __repr__(self): 968 return '<sequence or map containing \'%s\'>' % self._key 969 970 971class ContainsKeyValue(Comparator): 972 """Checks whether a key/value pair is in a dict parameter. 973 974 Example: 975 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) 976 """ 977 978 def __init__(self, key, value): 979 """Initialize. 980 981 Args: 982 # key: a key in a dict 983 # value: the corresponding value 984 """ 985 986 self._key = key 987 self._value = value 988 989 def equals(self, rhs): 990 """Check whether the given key/value pair is in the rhs dict. 991 992 Returns: 993 bool 994 """ 995 996 try: 997 return rhs[self._key] == self._value 998 except Exception: 999 return False 1000 1001 def __repr__(self): 1002 return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) 1003 1004 1005class SameElementsAs(Comparator): 1006 """Checks whether iterables contain the same elements (ignoring order). 1007 1008 Example: 1009 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) 1010 """ 1011 1012 def __init__(self, expected_seq): 1013 """Initialize. 1014 1015 Args: 1016 expected_seq: a sequence 1017 """ 1018 1019 self._expected_seq = expected_seq 1020 1021 def equals(self, actual_seq): 1022 """Check to see whether actual_seq has same elements as expected_seq. 1023 1024 Args: 1025 actual_seq: sequence 1026 1027 Returns: 1028 bool 1029 """ 1030 1031 try: 1032 expected = dict([(element, None) for element in self._expected_seq]) 1033 actual = dict([(element, None) for element in actual_seq]) 1034 except TypeError: 1035 # Fall back to slower list-compare if any of the objects are unhashable. 1036 expected = list(self._expected_seq) 1037 actual = list(actual_seq) 1038 expected.sort() 1039 actual.sort() 1040 return expected == actual 1041 1042 def __repr__(self): 1043 return '<sequence with same elements as \'%s\'>' % self._expected_seq 1044 1045 1046class And(Comparator): 1047 """Evaluates one or more Comparators on RHS and returns an AND of the results. 1048 """ 1049 1050 def __init__(self, *args): 1051 """Initialize. 1052 1053 Args: 1054 *args: One or more Comparator 1055 """ 1056 1057 self._comparators = args 1058 1059 def equals(self, rhs): 1060 """Checks whether all Comparators are equal to rhs. 1061 1062 Args: 1063 # rhs: can be anything 1064 1065 Returns: 1066 bool 1067 """ 1068 1069 for comparator in self._comparators: 1070 if not comparator.equals(rhs): 1071 return False 1072 1073 return True 1074 1075 def __repr__(self): 1076 return '<AND %s>' % str(self._comparators) 1077 1078 1079class Or(Comparator): 1080 """Evaluates one or more Comparators on RHS and returns an OR of the results. 1081 """ 1082 1083 def __init__(self, *args): 1084 """Initialize. 1085 1086 Args: 1087 *args: One or more Mox comparators 1088 """ 1089 1090 self._comparators = args 1091 1092 def equals(self, rhs): 1093 """Checks whether any Comparator is equal to rhs. 1094 1095 Args: 1096 # rhs: can be anything 1097 1098 Returns: 1099 bool 1100 """ 1101 1102 for comparator in self._comparators: 1103 if comparator.equals(rhs): 1104 return True 1105 1106 return False 1107 1108 def __repr__(self): 1109 return '<OR %s>' % str(self._comparators) 1110 1111 1112class Func(Comparator): 1113 """Call a function that should verify the parameter passed in is correct. 1114 1115 You may need the ability to perform more advanced operations on the parameter 1116 in order to validate it. You can use this to have a callable validate any 1117 parameter. The callable should return either True or False. 1118 1119 1120 Example: 1121 1122 def myParamValidator(param): 1123 # Advanced logic here 1124 return True 1125 1126 mock_dao.DoSomething(Func(myParamValidator), true) 1127 """ 1128 1129 def __init__(self, func): 1130 """Initialize. 1131 1132 Args: 1133 func: callable that takes one parameter and returns a bool 1134 """ 1135 1136 self._func = func 1137 1138 def equals(self, rhs): 1139 """Test whether rhs passes the function test. 1140 1141 rhs is passed into func. 1142 1143 Args: 1144 rhs: any python object 1145 1146 Returns: 1147 the result of func(rhs) 1148 """ 1149 1150 return self._func(rhs) 1151 1152 def __repr__(self): 1153 return str(self._func) 1154 1155 1156class IgnoreArg(Comparator): 1157 """Ignore an argument. 1158 1159 This can be used when we don't care about an argument of a method call. 1160 1161 Example: 1162 # Check if CastMagic is called with 3 as first arg and 'disappear' as third. 1163 mymock.CastMagic(3, IgnoreArg(), 'disappear') 1164 """ 1165 1166 def equals(self, unused_rhs): 1167 """Ignores arguments and returns True. 1168 1169 Args: 1170 unused_rhs: any python object 1171 1172 Returns: 1173 always returns True 1174 """ 1175 1176 return True 1177 1178 def __repr__(self): 1179 return '<IgnoreArg>' 1180 1181 1182class MethodGroup(object): 1183 """Base class containing common behaviour for MethodGroups.""" 1184 1185 def __init__(self, group_name): 1186 self._group_name = group_name 1187 1188 def group_name(self): 1189 return self._group_name 1190 1191 def __str__(self): 1192 return '<%s "%s">' % (self.__class__.__name__, self._group_name) 1193 1194 def AddMethod(self, mock_method): 1195 raise NotImplementedError 1196 1197 def MethodCalled(self, mock_method): 1198 raise NotImplementedError 1199 1200 def IsSatisfied(self): 1201 raise NotImplementedError 1202 1203class UnorderedGroup(MethodGroup): 1204 """UnorderedGroup holds a set of method calls that may occur in any order. 1205 1206 This construct is helpful for non-deterministic events, such as iterating 1207 over the keys of a dict. 1208 """ 1209 1210 def __init__(self, group_name): 1211 super(UnorderedGroup, self).__init__(group_name) 1212 self._methods = [] 1213 1214 def AddMethod(self, mock_method): 1215 """Add a method to this group. 1216 1217 Args: 1218 mock_method: A mock method to be added to this group. 1219 """ 1220 1221 self._methods.append(mock_method) 1222 1223 def MethodCalled(self, mock_method): 1224 """Remove a method call from the group. 1225 1226 If the method is not in the set, an UnexpectedMethodCallError will be 1227 raised. 1228 1229 Args: 1230 mock_method: a mock method that should be equal to a method in the group. 1231 1232 Returns: 1233 The mock method from the group 1234 1235 Raises: 1236 UnexpectedMethodCallError if the mock_method was not in the group. 1237 """ 1238 1239 # Check to see if this method exists, and if so, remove it from the set 1240 # and return it. 1241 for method in self._methods: 1242 if method == mock_method: 1243 # Remove the called mock_method instead of the method in the group. 1244 # The called method will match any comparators when equality is checked 1245 # during removal. The method in the group could pass a comparator to 1246 # another comparator during the equality check. 1247 self._methods.remove(mock_method) 1248 1249 # If this group is not empty, put it back at the head of the queue. 1250 if not self.IsSatisfied(): 1251 mock_method._call_queue.appendleft(self) 1252 1253 return self, method 1254 1255 raise UnexpectedMethodCallError(mock_method, self) 1256 1257 def IsSatisfied(self): 1258 """Return True if there are not any methods in this group.""" 1259 1260 return len(self._methods) == 0 1261 1262 1263class MultipleTimesGroup(MethodGroup): 1264 """MultipleTimesGroup holds methods that may be called any number of times. 1265 1266 Note: Each method must be called at least once. 1267 1268 This is helpful, if you don't know or care how many times a method is called. 1269 """ 1270 1271 def __init__(self, group_name): 1272 super(MultipleTimesGroup, self).__init__(group_name) 1273 self._methods = set() 1274 self._methods_called = set() 1275 1276 def AddMethod(self, mock_method): 1277 """Add a method to this group. 1278 1279 Args: 1280 mock_method: A mock method to be added to this group. 1281 """ 1282 1283 self._methods.add(mock_method) 1284 1285 def MethodCalled(self, mock_method): 1286 """Remove a method call from the group. 1287 1288 If the method is not in the set, an UnexpectedMethodCallError will be 1289 raised. 1290 1291 Args: 1292 mock_method: a mock method that should be equal to a method in the group. 1293 1294 Returns: 1295 The mock method from the group 1296 1297 Raises: 1298 UnexpectedMethodCallError if the mock_method was not in the group. 1299 """ 1300 1301 # Check to see if this method exists, and if so add it to the set of 1302 # called methods. 1303 1304 for method in self._methods: 1305 if method == mock_method: 1306 self._methods_called.add(mock_method) 1307 # Always put this group back on top of the queue, because we don't know 1308 # when we are done. 1309 mock_method._call_queue.appendleft(self) 1310 return self, method 1311 1312 if self.IsSatisfied(): 1313 next_method = mock_method._PopNextMethod(); 1314 return next_method, None 1315 else: 1316 raise UnexpectedMethodCallError(mock_method, self) 1317 1318 def IsSatisfied(self): 1319 """Return True if all methods in this group are called at least once.""" 1320 # NOTE(psycho): We can't use the simple set difference here because we want 1321 # to match different parameters which are considered the same e.g. IsA(str) 1322 # and some string. This solution is O(n^2) but n should be small. 1323 tmp = self._methods.copy() 1324 for called in self._methods_called: 1325 for expected in tmp: 1326 if called == expected: 1327 tmp.remove(expected) 1328 if not tmp: 1329 return True 1330 break 1331 return False 1332 1333 1334class MoxMetaTestBase(type): 1335 """Metaclass to add mox cleanup and verification to every test. 1336 1337 As the mox unit testing class is being constructed (MoxTestBase or a 1338 subclass), this metaclass will modify all test functions to call the 1339 CleanUpMox method of the test class after they finish. This means that 1340 unstubbing and verifying will happen for every test with no additional code, 1341 and any failures will result in test failures as opposed to errors. 1342 """ 1343 1344 def __init__(cls, name, bases, d): 1345 type.__init__(cls, name, bases, d) 1346 1347 # also get all the attributes from the base classes to account 1348 # for a case when test class is not the immediate child of MoxTestBase 1349 for base in bases: 1350 for attr_name in dir(base): 1351 d[attr_name] = getattr(base, attr_name) 1352 1353 for func_name, func in d.items(): 1354 if func_name.startswith('test') and callable(func): 1355 setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) 1356 1357 @staticmethod 1358 def CleanUpTest(cls, func): 1359 """Adds Mox cleanup code to any MoxTestBase method. 1360 1361 Always unsets stubs after a test. Will verify all mocks for tests that 1362 otherwise pass. 1363 1364 Args: 1365 cls: MoxTestBase or subclass; the class whose test method we are altering. 1366 func: method; the method of the MoxTestBase test class we wish to alter. 1367 1368 Returns: 1369 The modified method. 1370 """ 1371 def new_method(self, *args, **kwargs): 1372 mox_obj = getattr(self, 'mox', None) 1373 cleanup_mox = False 1374 if mox_obj and isinstance(mox_obj, Mox): 1375 cleanup_mox = True 1376 try: 1377 func(self, *args, **kwargs) 1378 finally: 1379 if cleanup_mox: 1380 mox_obj.UnsetStubs() 1381 if cleanup_mox: 1382 mox_obj.VerifyAll() 1383 new_method.__name__ = func.__name__ 1384 new_method.__doc__ = func.__doc__ 1385 new_method.__module__ = func.__module__ 1386 return new_method 1387 1388 1389class MoxTestBase(unittest.TestCase): 1390 """Convenience test class to make stubbing easier. 1391 1392 Sets up a "mox" attribute which is an instance of Mox - any mox tests will 1393 want this. Also automatically unsets any stubs and verifies that all mock 1394 methods have been called at the end of each test, eliminating boilerplate 1395 code. 1396 """ 1397 1398 __metaclass__ = MoxMetaTestBase 1399 1400 def setUp(self): 1401 self.mox = Mox() 1402