#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Testing utilities for the webapp libraries.

  GetDefaultEnvironment: Function for easily setting up CGI environment.
  RequestHandlerTestBase: Base class for setting up handler tests.
  ServerThread: Thread that serves a WSGI server until told to shut down.
  TestService, AlternateService: Services used by end-to-end tests.
  WebServerTestBase: Base class for tests that need a live local web server.
  EndToEndTestBase: Base class for end-to-end tests against TestService.
"""

__author__ = 'rafek@google.com (Rafe Kaplan)'

import cStringIO
import threading
import urllib2
from wsgiref import simple_server
from wsgiref import validate

from . import protojson
from . import remote
from . import test_util
from . import transport
from .webapp import service_handlers
from .webapp.google_imports import webapp


def GetDefaultEnvironment():
  """Function for creating a default CGI environment.

  A new dictionary, with fresh 'wsgi.input' and 'wsgi.errors' buffers, is
  built on every call, so callers may modify the result freely.

  Returns:
    dict of CGI/WSGI environment values describing a GET request for
    '/tmp/myhandler' on localhost:8080.
  """
  return {
      'LC_NUMERIC': 'C',
      'wsgi.multiprocess': True,
      'SERVER_PROTOCOL': 'HTTP/1.0',
      'SERVER_SOFTWARE': 'Dev AppServer 0.1',
      'SCRIPT_NAME': '',
      'LOGNAME': 'nickjohnson',
      'USER': 'nickjohnson',
      'QUERY_STRING': 'foo=bar&foo=baz&foo2=123',
      'PATH': '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/bin/X11',
      'LANG': 'en_US',
      'LANGUAGE': 'en',
      'REMOTE_ADDR': '127.0.0.1',
      'LC_MONETARY': 'C',
      'CONTENT_TYPE': 'application/x-www-form-urlencoded',
      'wsgi.url_scheme': 'http',
      'SERVER_PORT': '8080',
      'HOME': '/home/mruser',
      'USERNAME': 'mruser',
      'CONTENT_LENGTH': '',
      'USER_IS_ADMIN': '1',
      'PYTHONPATH': '/tmp/setup',
      'LC_TIME': 'C',
      'HTTP_USER_AGENT': 'Mozilla/5.0 (X11; U; Linux i686 (x86_64); en-US; '
                         'rv:1.8.1.6) Gecko/20070725 Firefox/2.0.0.6',
      'wsgi.multithread': False,
      'wsgi.version': (1, 0),
      # Must agree with the email in HTTP_COOKIE below.
      'USER_EMAIL': 'test@example.com',
      # Kept under its own key: a second 'USER_EMAIL' entry here would
      # silently overwrite the email above.
      'USER_ID': '112',
      'wsgi.input': cStringIO.StringIO(),
      'PATH_TRANSLATED': '/tmp/request.py',
      'SERVER_NAME': 'localhost',
      'GATEWAY_INTERFACE': 'CGI/1.1',
      'wsgi.run_once': True,
      'LC_COLLATE': 'C',
      'HOSTNAME': 'myhost',
      'wsgi.errors': cStringIO.StringIO(),
      'PWD': '/tmp',
      'REQUEST_METHOD': 'GET',
      'MAIL': '/dev/null',
      'MAILCHECK': '0',
      'USER_NICKNAME': 'test',
      'HTTP_COOKIE': 'dev_appserver_login="test:test@example.com:True"',
      'PATH_INFO': '/tmp/myhandler',
  }


class RequestHandlerTestBase(test_util.TestCase):
  """Base class for writing RequestHandler tests.

  To test a specific request handler override CreateRequestHandler.
  To change the environment for that handler override GetEnvironment.
  """

  def setUp(self):
    """Set up test for request handler."""
    super(RequestHandlerTestBase, self).setUp()
    self.ResetHandler()

  def GetEnvironment(self):
    """Get environment.

    Override for more specific configurations.

    Returns:
      dict of CGI environment.
    """
    return GetDefaultEnvironment()

  def CreateRequestHandler(self):
    """Create RequestHandler instances.

    Override to create more specific kinds of RequestHandler instances.

    Returns:
      RequestHandler instance used in test.
    """
    return webapp.RequestHandler()

  def CheckResponse(self,
                    expected_status,
                    expected_headers,
                    expected_content):
    """Check that the web response is as expected.

    Args:
      expected_status: Expected status message.
      expected_headers: Dictionary of expected headers. Will ignore unexpected
        headers and only check the value of those expected. Header names are
        compared case-insensitively.
      expected_content: Expected body.
    """
    # Actual header names are lower-cased below, so normalize the expected
    # names the same way; otherwise a mixed-case expectation never matches.
    expected = dict((name.lower(), value)
                    for name, value in expected_headers.items())

    def check_content(content):
      self.assertEqual(expected_content, content)

    def start_response(status, headers, exc_info=None):
      # exc_info is part of the WSGI start_response signature; it is unused.
      self.assertEqual(expected_status, status)

      found_keys = set()
      for name, value in headers:
        name = name.lower()
        if name in expected:
          found_keys.add(name)
          self.assertEqual(expected[name], value)

      missing_headers = set(expected) - found_keys
      if missing_headers:
        self.fail('Expected keys %r not found' % (list(missing_headers),))

      # The returned callable receives the body written by the response.
      return check_content

    self.handler.response.wsgi_write(start_response)

  def ResetHandler(self, change_environ=None):
    """Reset this test's environment with environment changes.

    Resets the entire test with a new handler which includes some changes to
    the default request environment.

    Args:
      change_environ: Dictionary of values that are added to default
        environment.
    """
    environment = self.GetEnvironment()
    environment.update(change_environ or {})

    self.request = webapp.Request(environment)
    self.response = webapp.Response()
    self.handler = self.CreateRequestHandler()
    self.handler.initialize(self.request, self.response)


class SyncedWSGIServer(simple_server.WSGIServer):
  """Subclass of simple_server.WSGIServer that adds no behavior.

  Not used within this module; StartWebServer uses make_server's default
  server class. Presumably kept for code that imports it.
  """


class ServerThread(threading.Thread):
  """Thread responsible for managing wsgi server.

  This server does not just attach to the socket and listen for requests. This
  is because the server classes in Python 2.5 or less have no way to shut them
  down. Instead, the thread repeatedly calls server.handle_request(). Each
  call waits only briefly (server.timeout) for a connection, so the loop
  re-checks after every call whether shutdown() has been requested. When the
  loop ends, the server socket is closed.
  """

  def __init__(self, server, *args, **kwargs):
    """Constructor.

    Args:
      server: The WSGI server that is served by this thread.
      *args, **kwargs: Passed through to the threading.Thread constructor.

    State:
      __serving: Server is still expected to be serving. When False server
        knows to shut itself down.
    """
    self.server = server
    # This timeout is for the socket when a connection is made.
    self.server.socket.settimeout(None)
    # This timeout is for when waiting for a connection. It allows
    # server.handle_request() to listen for a short time, then time out,
    # allowing the server to check for shutdown.
    self.server.timeout = 0.05
    self.__serving = True

    super(ServerThread, self).__init__(*args, **kwargs)

  def shutdown(self):
    """Notify server that it must shutdown gracefully.

    Only sets a flag and returns immediately. The serving loop exits after its
    current handle_request() call returns; use join() to wait for that.
    """
    self.__serving = False

  def run(self):
    """Handle incoming requests until shutdown, then close the server."""
    try:
      while self.__serving:
        self.server.handle_request()
    finally:
      # Release the listening socket; without this every server started by a
      # test (and every ResetServer call) leaks an open socket.
      self.server.server_close()
      self.server = None


class TestService(remote.Service):
  """Service used to do end to end tests with."""

  def __init__(self, message='uninitialized'):
    # Returned by init_parameter; set via TestService.new_factory(...).
    self.__message = message

  @remote.method(test_util.OptionalMessage, test_util.OptionalMessage)
  def optional_message(self, request):
    """Echo the request, prefixing string_value with '+' when it is set."""
    if request.string_value:
      request.string_value = '+%s' % request.string_value
    return request

  @remote.method(response_type=test_util.OptionalMessage)
  def init_parameter(self, request):
    """Return the message this service instance was constructed with."""
    return test_util.OptionalMessage(string_value=self.__message)

  @remote.method(test_util.NestedMessage, test_util.NestedMessage)
  def nested_message(self, request):
    """Echo the request with '+' prefixed to string_value."""
    request.string_value = '+%s' % request.string_value
    return request

  @remote.method()
  def raise_application_error(self, request):
    """Always raise remote.ApplicationError with error name 'ERROR_NAME'."""
    raise remote.ApplicationError('This is an application error', 'ERROR_NAME')

  @remote.method()
  def raise_unexpected_error(self, request):
    """Always raise a TypeError."""
    raise TypeError('Unexpected error')

  @remote.method()
  def raise_rpc_error(self, request):
    """Always raise remote.NetworkError."""
    raise remote.NetworkError('Uncaught network error')

  @remote.method(response_type=test_util.NestedMessage)
  def return_bad_message(self, request):
    """Return a NestedMessage with no fields set."""
    return test_util.NestedMessage()


class AlternateService(remote.Service):
  """Service used for requesting non-existent methods."""

  @remote.method()
  def does_not_exist(self, request):
    raise NotImplementedError('Not implemented')


class WebServerTestBase(test_util.TestCase):
  """Base class for tests that run a WSGI application on a local web server.

  Subclasses must provide CreateWsgiApplication (or pass an application to
  ResetServer / StartWebServer).
  """

  SERVICE_PATH = '/my/service'

  def setUp(self):
    self.server = None
    self.schema = 'http'
    self.ResetServer()

    self.bad_path_connection = self.CreateTransport(self.service_url + '_x')
    self.bad_path_stub = TestService.Stub(self.bad_path_connection)
    super(WebServerTestBase, self).setUp()

  def tearDown(self):
    self.server.shutdown()
    super(WebServerTestBase, self).tearDown()

  def ResetServer(self, application=None):
    """Reset web server.

    Shuts down existing server if necessary and starts a new one.

    Args:
      application: Optional WSGI function. If none provided will use
        tests CreateWsgiApplication method.
    """
    if self.server:
      self.server.shutdown()

    self.port = test_util.pick_unused_port()
    self.server, self.application = self.StartWebServer(self.port, application)

    self.connection = self.CreateTransport(self.service_url)

  def CreateTransport(self, service_url, protocol=protojson):
    """Create a new transportation object."""
    return transport.HttpTransport(service_url, protocol=protocol)

  def StartWebServer(self, port, application=None):
    """Start web server.

    Args:
      port: Port to start application on.
      application: Optional WSGI function. If none provided will use
        tests CreateWsgiApplication method.

    Returns:
      A tuple (server, application):
        server: An instance of ServerThread.
        application: Application that web server responds with.
    """
    if application is None:
      application = self.CreateWsgiApplication()
    # wsgiref's validator checks WSGI compliance of the application.
    validated_application = validate.validator(application)
    server = simple_server.make_server('localhost', port, validated_application)
    server = ServerThread(server)
    server.start()
    return server, application

  def make_service_url(self, path):
    """Make service URL using current schema and port."""
    return '%s://localhost:%d%s' % (self.schema, self.port, path)

  @property
  def service_url(self):
    return self.make_service_url(self.SERVICE_PATH)


class EndToEndTestBase(WebServerTestBase):
  """Base class for end-to-end tests of TestService over HTTP."""

  # Sub-classes may override to create alternate configurations.
  DEFAULT_MAPPING = service_handlers.service_mapping(
      [('/my/service', TestService),
       ('/my/other_service', TestService.new_factory('initialized')),
      ])

  def setUp(self):
    super(EndToEndTestBase, self).setUp()

    self.stub = TestService.Stub(self.connection)

    self.other_connection = self.CreateTransport(self.other_service_url)
    self.other_stub = TestService.Stub(self.other_connection)

    # Stub whose methods are not implemented by the service at service_url.
    self.mismatched_stub = AlternateService.Stub(self.connection)

  @property
  def other_service_url(self):
    # Use make_service_url so this URL follows self.schema like service_url.
    return self.make_service_url('/my/other_service')

  def CreateWsgiApplication(self):
    """Create WSGI application used on the server side for testing."""
    # Second argument is webapp.WSGIApplication's debug flag.
    return webapp.WSGIApplication(self.DEFAULT_MAPPING, True)

  def DoRawRequest(self,
                   method,
                   content='',
                   content_type='application/json',
                   headers=None):
    """Send a raw HTTP request to a method of the service under test.

    Args:
      method: Remote method name, appended to service_url after a '.'.
      content: Request body, passed to urllib2.Request as its data argument.
        Per urllib2, None produces a GET request; any other value (including
        '') produces a POST.
      content_type: Value of the content-type header.
      headers: Optional dictionary of extra request headers. It is copied and
        never modified.

    Returns:
      The response object returned by urllib2.urlopen.

    Raises:
      urllib2.HTTPError: if the server responds with an error status.
    """
    # Copy so the update below does not mutate the caller's dictionary.
    headers = dict(headers or {})
    headers.update({'content-length': len(content or ''),
                    'content-type': content_type,
                   })
    request = urllib2.Request('%s.%s' % (self.service_url, method),
                              content,
                              headers)
    return urllib2.urlopen(request)

  def RawRequestError(self,
                      method,
                      content=None,
                      content_type='application/json',
                      headers=None):
    """Send a raw request that is expected to fail with an HTTP error.

    Arguments are as for DoRawRequest. Fails the test if no HTTPError occurs.

    Returns:
      Tuple (status code, response body, response headers) of the HTTPError.
    """
    try:
      response = self.DoRawRequest(method, content, content_type, headers)
    except urllib2.HTTPError as err:
      return err.code, err.read(), err.headers
    else:
      # Unexpected success: release the connection before failing.
      response.close()
      self.fail('Expected HTTP error')