• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python2.4
2#
3# Copyright 2014 Google Inc. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
"""Tests for oauth2client.appengine.

Unit tests for the App Engine helpers in oauth2client: AppAssertionCredentials,
the Flow/Credentials datastore properties and models (db and NDB),
StorageByKeyName, the OAuth2Decorator family, and the XSRF helpers that protect
the OAuth 2.0 callback.
"""

__author__ = 'jcgregorio@google.com (Joe Gregorio)'
24
25import datetime
26import httplib2
27import json
28import os
29import time
30import unittest
31import urllib
32import urlparse
33
34import dev_appserver
35dev_appserver.fix_sys_path()
36import mock
37import webapp2
38
39from google.appengine.api import apiproxy_stub
40from google.appengine.api import apiproxy_stub_map
41from google.appengine.api import app_identity
42from google.appengine.api import memcache
43from google.appengine.api import users
44from google.appengine.api.memcache import memcache_stub
45from google.appengine.ext import db
46from google.appengine.ext import ndb
47from google.appengine.ext import testbed
48from google.appengine.runtime import apiproxy_errors
49from oauth2client import appengine
50from oauth2client import GOOGLE_TOKEN_URI
51from oauth2client.clientsecrets import _loadfile
52from oauth2client.clientsecrets import InvalidClientSecretsError
53from oauth2client.appengine import AppAssertionCredentials
54from oauth2client.appengine import CredentialsModel
55from oauth2client.appengine import CredentialsNDBModel
56from oauth2client.appengine import FlowNDBProperty
57from oauth2client.appengine import FlowProperty
58from oauth2client.appengine import OAuth2Decorator
59from oauth2client.appengine import OAuth2DecoratorFromClientSecrets
60from oauth2client.appengine import StorageByKeyName
61from oauth2client.client import AccessTokenRefreshError
62from oauth2client.client import Credentials
63from oauth2client.client import FlowExchangeError
64from oauth2client.client import OAuth2Credentials
65from oauth2client.client import flow_from_clientsecrets
66from oauth2client.client import save_to_well_known_file
67from webtest import TestApp
68
69
# Test fixtures live in a 'data' directory next to this module.
DATA_DIR = os.path.join(os.path.split(__file__)[0], 'data')


def datafile(filename):
  """Return the path of fixture ``filename`` inside DATA_DIR."""
  fixture_path = os.path.join(DATA_DIR, filename)
  return fixture_path
75
76
def load_and_cache(existing_file, fakename, cache_mock):
  """Parse fixture ``existing_file`` and seed ``cache_mock`` with it.

  The parsed client secrets are stored under the key ``fakename`` as a
  one-entry dict mapping client type to client info.

  Args:
    existing_file: fixture file name, resolved through datafile().
    fakename: cache key to store the parsed secrets under.
    cache_mock: object with a ``cache`` dict (e.g. CacheMock).
  """
  parsed_type, parsed_info = _loadfile(datafile(existing_file))
  entry = {parsed_type: parsed_info}
  cache_mock.cache[fakename] = entry
80
81
class CacheMock(object):
  """Dict-backed stand-in for a memcache-like cache.

  The ``namespace`` argument is accepted for interface compatibility but
  ignored, which keeps the tests simple.
  """

  def __init__(self):
    self.cache = dict()

  def get(self, key, namespace=''):
    # Namespace intentionally ignored; missing keys yield None.
    if key in self.cache:
      return self.cache[key]
    return None

  def set(self, key, value, namespace=''):
    # Namespace intentionally ignored; overwrites any existing value.
    self.cache.update({key: value})
93
94
class UserMock(object):
  """Mock of the App Engine user service with a logged-in user.

  An instance replaces users.get_current_user. Calling it returns the
  instance itself, which then plays the role of the User object.
  """

  def __call__(self):
    # get_current_user() -> user; this mock is both.
    current = self
    return current

  def user_id(self):
    """Return the fixed id of the mocked user."""
    return 'foo_user'
103
104
class UserNotLoggedInMock(object):
  """Mock of the App Engine user service with nobody logged in.

  An instance replaces users.get_current_user; calling it yields None.
  """

  def __call__(self):
    no_user = None
    return no_user
110
111
class Http2Mock(object):
  """Mock httplib2.Http that always answers a token request successfully.

  The instance doubles as the response object (it carries ``status``). The
  body is ``content`` serialized as JSON. The body and headers of the most
  recent request are recorded on the instance for inspection.
  """
  status = 200
  content = {
      'access_token': 'foo_access_token',
      'refresh_token': 'foo_refresh_token',
      'expires_in': 3600,
      'extra': 'value',
  }

  def __init__(self, *args, **kwargs):
    # Accept (and ignore) httplib2.Http constructor arguments, since this
    # class is installed in place of httplib2.Http.
    self.body = None
    self.headers = None

  def request(self, token_uri, method='GET', body=None, headers=None,
              *args, **kwargs):
    """Record the request and return ``(self, json.dumps(self.content))``.

    The defaults mirror httplib2.Http.request, so callers may omit method,
    body and headers.

    Args:
      token_uri: requested URI (ignored).
      method: HTTP method (ignored).
      body: request body; stored on ``self.body``.
      headers: request headers; stored on ``self.headers``.

    Returns:
      A ``(response, content)`` pair where the response is this instance.
    """
    self.body = body
    self.headers = headers
    return (self, json.dumps(self.content))
126
127
class TestAppAssertionCredentials(unittest.TestCase):
  """Tests for AppAssertionCredentials backed by the App Identity API."""

  account_name = "service_account_name@appspot.com"
  signature = "signature"

  class AppIdentityStubImpl(apiproxy_stub.APIProxyStub):
    """App Identity stub that always hands out token 'a_token_123'."""

    def __init__(self):
      super(TestAppAssertionCredentials.AppIdentityStubImpl, self).__init__(
          'app_identity_service')

    def _Dynamic_GetAccessToken(self, request, response):
      response.set_access_token('a_token_123')
      response.set_expiration_time(time.time() + 1800)

  class ErroringAppIdentityStubImpl(apiproxy_stub.APIProxyStub):
    """App Identity stub whose GetAccessToken always fails."""

    def __init__(self):
      super(TestAppAssertionCredentials.ErroringAppIdentityStubImpl,
            self).__init__('app_identity_service')

    def _Dynamic_GetAccessToken(self, request, response):
      raise app_identity.BackendDeadlineExceeded()

  def setUp(self):
    # Some tests install a fresh global API proxy; remember the original so
    # tearDown can put it back and the stubs do not leak into other tests.
    self._orig_apiproxy = apiproxy_stub_map.apiproxy

  def tearDown(self):
    apiproxy_stub_map.apiproxy = self._orig_apiproxy

  def _install_stubs(self, app_identity_stub):
    """Install a new global API proxy serving ``app_identity_stub`` + memcache."""
    apiproxy_stub_map.apiproxy = apiproxy_stub_map.APIProxyStubMap()
    apiproxy_stub_map.apiproxy.RegisterStub('app_identity_service',
                                            app_identity_stub)
    apiproxy_stub_map.apiproxy.RegisterStub(
        'memcache', memcache_stub.MemcacheServiceStub())

  def test_raise_correct_type_of_exception(self):
    self._install_stubs(self.ErroringAppIdentityStubImpl())

    scope = 'http://www.googleapis.com/scope'
    credentials = AppAssertionCredentials(scope)
    http = httplib2.Http()
    self.assertRaises(AccessTokenRefreshError, credentials.refresh, http)

  def test_get_access_token_on_refresh(self):
    self._install_stubs(self.AppIdentityStubImpl())

    scope = [
        "http://www.googleapis.com/scope",
        "http://www.googleapis.com/scope2"]
    credentials = AppAssertionCredentials(scope)
    http = httplib2.Http()
    credentials.refresh(http)
    self.assertEqual('a_token_123', credentials.access_token)

    # A list of scopes survives a JSON round trip as a space-joined string.
    serialized = credentials.to_json()
    credentials = Credentials.new_from_json(serialized)
    self.assertEqual(
        'http://www.googleapis.com/scope http://www.googleapis.com/scope2',
        credentials.scope)

    # A space-separated scope string is accepted as-is.
    scope = "http://www.googleapis.com/scope http://www.googleapis.com/scope2"
    credentials = AppAssertionCredentials(scope)
    http = httplib2.Http()
    credentials.refresh(http)
    self.assertEqual('a_token_123', credentials.access_token)
    self.assertEqual(
        'http://www.googleapis.com/scope http://www.googleapis.com/scope2',
        credentials.scope)

  def test_custom_service_account(self):
    scope = "http://www.googleapis.com/scope"
    account_id = "service_account_name_2@appspot.com"

    with mock.patch.object(app_identity, 'get_access_token',
                           return_value=('a_token_456', None),
                           autospec=True) as get_access_token:
      credentials = AppAssertionCredentials(
          scope, service_account_id=account_id)
      http = httplib2.Http()
      credentials.refresh(http)

      self.assertEqual('a_token_456', credentials.access_token)
      self.assertEqual(scope, credentials.scope)
      get_access_token.assert_called_once_with(
          [scope], service_account_id=account_id)

  def test_create_scoped_required_without_scopes(self):
    credentials = AppAssertionCredentials([])
    self.assertTrue(credentials.create_scoped_required())

  def test_create_scoped_required_with_scopes(self):
    credentials = AppAssertionCredentials(['dummy_scope'])
    self.assertFalse(credentials.create_scoped_required())

  def test_create_scoped(self):
    credentials = AppAssertionCredentials([])
    new_credentials = credentials.create_scoped(['dummy_scope'])
    self.assertNotEqual(credentials, new_credentials)
    self.assertIsInstance(new_credentials, AppAssertionCredentials)
    self.assertEqual('dummy_scope', new_credentials.scope)

  def test_get_access_token(self):
    self._install_stubs(self.AppIdentityStubImpl())

    credentials = AppAssertionCredentials(['dummy_scope'])
    token = credentials.get_access_token()
    self.assertEqual('a_token_123', token.access_token)
    self.assertIsNone(token.expires_in)

  def test_save_to_well_known_file(self):
    credentials = AppAssertionCredentials([])
    self.assertRaises(NotImplementedError, save_to_well_known_file, credentials)
245
246
class TestFlowModel(db.Model):
  """db.Model holding a single FlowProperty, used for round-trip tests."""
  flow = FlowProperty()


class FlowPropertyTest(unittest.TestCase):
  """A Flow stored through FlowProperty must survive a datastore round trip."""

  def setUp(self):
    bed = testbed.Testbed()
    bed.activate()
    bed.init_datastore_v3_stub()
    self.testbed = bed

  def tearDown(self):
    self.testbed.deactivate()

  def test_flow_get_put(self):
    stored_flow = flow_from_clientsecrets(datafile('client_secrets.json'),
                                          'foo', redirect_uri='oob')
    TestFlowModel(flow=stored_flow, key_name='foo').put()

    loaded = TestFlowModel.get_by_key_name('foo')
    self.assertEqual('foo_client_id', loaded.flow.client_id)
271
272
class TestFlowNDBModel(ndb.Model):
  """ndb.Model holding a single FlowNDBProperty, used for round-trip tests."""
  flow = FlowNDBProperty()


class FlowNDBPropertyTest(unittest.TestCase):
  """A Flow stored through FlowNDBProperty must survive a put/get round trip."""

  def setUp(self):
    bed = testbed.Testbed()
    bed.activate()
    bed.init_datastore_v3_stub()
    bed.init_memcache_stub()
    self.testbed = bed

  def tearDown(self):
    self.testbed.deactivate()

  def test_flow_get_put(self):
    stored_flow = flow_from_clientsecrets(datafile('client_secrets.json'),
                                          'foo', redirect_uri='oob')
    TestFlowNDBModel(flow=stored_flow, id='foo').put()

    loaded = TestFlowNDBModel.get_by_id('foo')
    self.assertEqual('foo_client_id', loaded.flow.client_id)
298
299
def _http_request(*args, **kwargs):
  """Stand-in for Http.request: always 200 with access token 'bar'.

  All arguments are ignored.
  """
  ok_response = httplib2.Response({'status': '200'})
  payload = json.dumps({'access_token': 'bar'})
  return ok_response, payload
305
306
class StorageByKeyNameTest(unittest.TestCase):
  """StorageByKeyName over db and NDB credential models, with/without cache."""

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_datastore_v3_stub()
    self.testbed.init_memcache_stub()
    self.testbed.init_user_stub()

    access_token = 'foo'
    client_id = 'some_client_id'
    client_secret = 'cOuDdkfjxxnv+'
    refresh_token = '1/0/a.df219fjls0'
    token_expiry = datetime.datetime.utcnow()
    user_agent = 'refresh_checker/1.0'
    self.credentials = OAuth2Credentials(
        access_token, client_id, client_secret,
        refresh_token, token_expiry, GOOGLE_TOKEN_URI,
        user_agent)

  def tearDown(self):
    self.testbed.deactivate()

  def test_get_and_put_simple(self):
    storage = StorageByKeyName(
        CredentialsModel, 'foo', 'credentials')

    self.assertIsNone(storage.get())
    self.credentials.set_store(storage)

    # Refreshing writes the new token ('bar', from _http_request) to storage.
    self.credentials._refresh(_http_request)
    credmodel = CredentialsModel.get_by_key_name('foo')
    self.assertEqual('bar', credmodel.credentials.access_token)

  def test_get_and_put_cached(self):
    storage = StorageByKeyName(
        CredentialsModel, 'foo', 'credentials', cache=memcache)

    self.assertIsNone(storage.get())
    self.credentials.set_store(storage)

    self.credentials._refresh(_http_request)
    credmodel = CredentialsModel.get_by_key_name('foo')
    self.assertEqual('bar', credmodel.credentials.access_token)

    # Now remove the item from the cache.
    memcache.delete('foo')

    # Check that getting refreshes the cache.
    credentials = storage.get()
    self.assertEqual('bar', credentials.access_token)
    self.assertIsNotNone(memcache.get('foo'))

    # Deleting should clear the cache.
    storage.delete()
    credentials = storage.get()
    self.assertIsNone(credentials)
    self.assertIsNone(memcache.get('foo'))

  def test_get_and_put_set_store_on_cache_retrieval(self):
    storage = StorageByKeyName(
        CredentialsModel, 'foo', 'credentials', cache=memcache)

    self.assertIsNone(storage.get())
    self.credentials.set_store(storage)
    storage.put(self.credentials)
    # Pre-bug 292 old_creds wouldn't have storage, and the _refresh wouldn't
    # be able to store the updated cred back into the storage.
    old_creds = storage.get()
    self.assertEqual('foo', old_creds.access_token)
    old_creds.invalid = True
    old_creds._refresh(_http_request)
    new_creds = storage.get()
    self.assertEqual('bar', new_creds.access_token)

  def test_get_and_put_ndb(self):
    # Start empty
    storage = StorageByKeyName(
        CredentialsNDBModel, 'foo', 'credentials')
    self.assertIsNone(storage.get())

    # Refresh storage and retrieve without using storage
    self.credentials.set_store(storage)
    self.credentials._refresh(_http_request)
    credmodel = CredentialsNDBModel.get_by_id('foo')
    self.assertEqual('bar', credmodel.credentials.access_token)
    self.assertEqual(credmodel.credentials.to_json(),
                     self.credentials.to_json())

  def test_delete_ndb(self):
    # Start empty
    storage = StorageByKeyName(
        CredentialsNDBModel, 'foo', 'credentials')
    self.assertIsNone(storage.get())

    # Add credentials to model with storage, and check equivalent w/o storage
    storage.put(self.credentials)
    credmodel = CredentialsNDBModel.get_by_id('foo')
    self.assertEqual(credmodel.credentials.to_json(),
                     self.credentials.to_json())

    # Delete and make sure empty
    storage.delete()
    self.assertIsNone(storage.get())

  def test_get_and_put_mixed_ndb_storage_db_get(self):
    # Start empty
    storage = StorageByKeyName(
        CredentialsNDBModel, 'foo', 'credentials')
    self.assertIsNone(storage.get())

    # Set NDB store and refresh to add to storage
    self.credentials.set_store(storage)
    self.credentials._refresh(_http_request)

    # Retrieve same key from DB model to confirm mixing works
    credmodel = CredentialsModel.get_by_key_name('foo')
    self.assertEqual('bar', credmodel.credentials.access_token)
    self.assertEqual(self.credentials.to_json(),
                     credmodel.credentials.to_json())

  def test_get_and_put_mixed_db_storage_ndb_get(self):
    # Start empty
    storage = StorageByKeyName(
        CredentialsModel, 'foo', 'credentials')
    self.assertIsNone(storage.get())

    # Set DB store and refresh to add to storage
    self.credentials.set_store(storage)
    self.credentials._refresh(_http_request)

    # Retrieve same key from NDB model to confirm mixing works
    credmodel = CredentialsNDBModel.get_by_id('foo')
    self.assertEqual('bar', credmodel.credentials.access_token)
    self.assertEqual(self.credentials.to_json(),
                     credmodel.credentials.to_json())

  def test_delete_db_ndb_mixed(self):
    # Start empty
    storage_ndb = StorageByKeyName(
        CredentialsNDBModel, 'foo', 'credentials')
    storage = StorageByKeyName(
        CredentialsModel, 'foo', 'credentials')

    # First DB, then NDB
    self.assertIsNone(storage.get())
    storage.put(self.credentials)
    self.assertIsNotNone(storage.get())

    storage_ndb.delete()
    self.assertIsNone(storage.get())

    # First NDB, then DB
    self.assertIsNone(storage_ndb.get())
    storage_ndb.put(self.credentials)

    storage.delete()
    # NDB still returns the entity here because it is cached (memcache and
    # the Context instance cache); clearing both makes the delete visible.
    self.assertIsNotNone(storage_ndb.get())
    ndb.get_context().clear_cache()
    memcache.flush_all()
    self.assertIsNone(storage_ndb.get())
469
470
class MockRequest(object):
  """Minimal webapp2 request stand-in exposing ``url`` and ``relative_url``."""
  url = 'https://example.org'

  def relative_url(self, rel):
    # No separator is inserted, so ``rel`` must carry its own leading '/'.
    base = self.url
    return base + rel


class MockRequestHandler(object):
  """Request-handler stand-in; ``request`` is one MockRequest shared by all."""
  request = MockRequest()
480
481
class DecoratorTests(unittest.TestCase):
  """End-to-end tests of OAuth2Decorator and OAuth2DecoratorFromClientSecrets.

  Each test drives a small webapp2 application through webtest. The token
  endpoint is served by Http2Mock and the current user by UserMock (or
  UserNotLoggedInMock).
  """

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_datastore_v3_stub()
    self.testbed.init_memcache_stub()
    self.testbed.init_user_stub()

    decorator = OAuth2Decorator(client_id='foo_client_id',
                                client_secret='foo_client_secret',
                                scope=['foo_scope', 'bar_scope'],
                                user_agent='foo')

    self._finish_setup(decorator, user_mock=UserMock)

  def _finish_setup(self, decorator, user_mock):
    """Build the test app around ``decorator`` and install the mocks.

    Several tests call this a second time with a different decorator. The
    global patches are undone through addCleanup, which runs in LIFO order,
    so the real httplib2.Http and users.get_current_user are restored even
    after repeated calls.

    Args:
      decorator: the OAuth2Decorator under test.
      user_mock: class whose instance replaces users.get_current_user.
    """
    self.decorator = decorator
    self.had_credentials = False
    self.found_credentials = None
    self.should_raise = False
    parent = self

    class TestRequiredHandler(webapp2.RequestHandler):
      @decorator.oauth_required
      def get(self):
        if decorator.has_credentials():
          parent.had_credentials = True
          parent.found_credentials = decorator.credentials
        if parent.should_raise:
          raise Exception('')

    class TestAwareHandler(webapp2.RequestHandler):
      @decorator.oauth_aware
      def get(self, *args, **kwargs):
        self.response.out.write('Hello World!')
        assert kwargs['year'] == '2012'
        assert kwargs['month'] == '01'
        if decorator.has_credentials():
          parent.had_credentials = True
          parent.found_credentials = decorator.credentials
        if parent.should_raise:
          raise Exception('')

    application = webapp2.WSGIApplication([
        ('/oauth2callback', self.decorator.callback_handler()),
        ('/foo_path', TestRequiredHandler),
        webapp2.Route(r'/bar_path/<year:\d{4}>/<month:\d{2}>',
                      handler=TestAwareHandler, name='bar')],
        debug=True)
    self.app = TestApp(application, extra_environ={
        'wsgi.url_scheme': 'http',
        'HTTP_HOST': 'localhost',
        })
    self.current_user = user_mock()
    self._patch(users, 'get_current_user', self.current_user)
    self._patch(httplib2, 'Http', Http2Mock)

  def _patch(self, target, attribute, replacement):
    """Replace ``target.attribute`` until the end of the current test."""
    patcher = mock.patch.object(target, attribute, replacement)
    patcher.start()
    self.addCleanup(patcher.stop)

  def tearDown(self):
    self.testbed.deactivate()

  def test_required(self):
    # An initial request to an oauth_required decorated path should be a
    # redirect to start the OAuth dance.
    self.assertIsNone(self.decorator.flow)
    self.assertIsNone(self.decorator.credentials)
    response = self.app.get('http://localhost/foo_path')
    self.assertTrue(response.status.startswith('302'))
    q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
    self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
    self.assertEqual('foo_client_id', q['client_id'][0])
    self.assertEqual('foo_scope bar_scope', q['scope'][0])
    self.assertEqual('http://localhost/foo_path',
                     q['state'][0].rsplit(':', 1)[0])
    self.assertEqual('code', q['response_type'][0])
    self.assertFalse(self.decorator.has_credentials())

    with mock.patch.object(appengine, '_parse_state_value',
                           return_value='foo_path',
                           autospec=True) as parse_state_value:
      # Now simulate the callback to /oauth2callback.
      response = self.app.get('/oauth2callback', {
          'code': 'foo_access_code',
          'state': 'foo_path:xsrfkey123',
          })
      parts = response.headers['Location'].split('?', 1)
      self.assertEqual('http://localhost/foo_path', parts[0])
      self.assertIsNone(self.decorator.credentials)
      if self.decorator._token_response_param:
        # The raw token response is echoed in the redirect's query string.
        response_query = urlparse.parse_qs(parts[1])
        token_response = response_query[
            self.decorator._token_response_param][0]
        self.assertEqual(Http2Mock.content,
                         json.loads(urllib.unquote(token_response)))
      self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
      self.assertEqual(self.decorator.credentials,
                       self.decorator._tls.credentials)

      parse_state_value.assert_called_once_with(
          'foo_path:xsrfkey123', self.current_user)

    # Now requesting the decorated path should work.
    response = self.app.get('/foo_path')
    self.assertEqual('200 OK', response.status)
    self.assertTrue(self.had_credentials)
    self.assertEqual('foo_refresh_token',
                     self.found_credentials.refresh_token)
    self.assertEqual('foo_access_token',
                     self.found_credentials.access_token)
    self.assertIsNone(self.decorator.credentials)

    # Raising an exception still clears the Credentials.
    self.should_raise = True
    self.assertRaises(Exception, self.app.get, '/foo_path')
    self.should_raise = False
    self.assertIsNone(self.decorator.credentials)

    # Invalidate the stored Credentials.
    self.found_credentials.invalid = True
    self.found_credentials.store.put(self.found_credentials)

    # Invalid Credentials should start the OAuth dance again.
    response = self.app.get('/foo_path')
    self.assertTrue(response.status.startswith('302'))
    q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
    self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])

  def test_storage_delete(self):
    # An initial request to an oauth_required decorated path should be a
    # redirect to start the OAuth dance.
    response = self.app.get('/foo_path')
    self.assertTrue(response.status.startswith('302'))

    with mock.patch.object(appengine, '_parse_state_value',
                           return_value='foo_path',
                           autospec=True) as parse_state_value:
      # Now simulate the callback to /oauth2callback.
      response = self.app.get('/oauth2callback', {
          'code': 'foo_access_code',
          'state': 'foo_path:xsrfkey123',
      })
      self.assertEqual('http://localhost/foo_path', response.headers['Location'])
      self.assertIsNone(self.decorator.credentials)

      # Now requesting the decorated path should work.
      response = self.app.get('/foo_path')
      self.assertEqual('200 OK', response.status)
      self.assertTrue(self.had_credentials)

      # Credentials should be cleared after each call.
      self.assertIsNone(self.decorator.credentials)

      # Delete the stored Credentials.
      self.found_credentials.store.delete()

      # Missing Credentials should start the OAuth dance again.
      response = self.app.get('/foo_path')
      self.assertTrue(response.status.startswith('302'))

      parse_state_value.assert_called_once_with(
          'foo_path:xsrfkey123', self.current_user)

  def test_aware(self):
    # An initial request to an oauth_aware decorated path should not redirect.
    response = self.app.get('http://localhost/bar_path/2012/01')
    self.assertEqual('Hello World!', response.body)
    self.assertEqual('200 OK', response.status)
    self.assertFalse(self.decorator.has_credentials())
    url = self.decorator.authorize_url()
    q = urlparse.parse_qs(url.split('?', 1)[1])
    self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
    self.assertEqual('foo_client_id', q['client_id'][0])
    self.assertEqual('foo_scope bar_scope', q['scope'][0])
    self.assertEqual('http://localhost/bar_path/2012/01',
                     q['state'][0].rsplit(':', 1)[0])
    self.assertEqual('code', q['response_type'][0])

    with mock.patch.object(appengine, '_parse_state_value',
                           return_value='bar_path',
                           autospec=True) as parse_state_value:
      # Now simulate the callback to /oauth2callback.
      response = self.app.get('/oauth2callback', {
          'code': 'foo_access_code',
          'state': 'bar_path:xsrfkey456',
          })

      self.assertEqual('http://localhost/bar_path', response.headers['Location'])
      self.assertFalse(self.decorator.has_credentials())
      parse_state_value.assert_called_once_with(
          'bar_path:xsrfkey456', self.current_user)

    # Now requesting the decorated path will have credentials.
    response = self.app.get('/bar_path/2012/01')
    self.assertEqual('200 OK', response.status)
    self.assertEqual('Hello World!', response.body)
    self.assertTrue(self.had_credentials)
    self.assertEqual('foo_refresh_token',
                     self.found_credentials.refresh_token)
    self.assertEqual('foo_access_token',
                     self.found_credentials.access_token)

    # Credentials should be cleared after each call.
    self.assertIsNone(self.decorator.credentials)

    # Raising an exception still clears the Credentials.
    self.should_raise = True
    self.assertRaises(Exception, self.app.get, '/bar_path/2012/01')
    self.should_raise = False
    self.assertIsNone(self.decorator.credentials)

  def test_error_in_step2(self):
    # Prime the decorator with a request to the oauth_aware path.
    self.app.get('/bar_path/2012/01')
    # An error reported to the callback must be rendered HTML-escaped.
    response = self.app.get('/oauth2callback', {
        'error': 'Bad<Stuff>Happened\''
        })
    self.assertEqual('200 OK', response.status)
    self.assertIn('Bad&lt;Stuff&gt;Happened&#39;', response.body)

  def test_kwargs_are_passed_to_underlying_flow(self):
    decorator = OAuth2Decorator(client_id='foo_client_id',
                                client_secret='foo_client_secret',
                                user_agent='foo_user_agent',
                                scope=['foo_scope', 'bar_scope'],
                                access_type='offline',
                                approval_prompt='force',
                                revoke_uri='dummy_revoke_uri')
    request_handler = MockRequestHandler()
    decorator._create_flow(request_handler)

    self.assertEqual('https://example.org/oauth2callback',
                     decorator.flow.redirect_uri)
    self.assertEqual('offline', decorator.flow.params['access_type'])
    self.assertEqual('force', decorator.flow.params['approval_prompt'])
    self.assertEqual('foo_user_agent', decorator.flow.user_agent)
    self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
    self.assertIsNone(decorator.flow.params.get('user_agent', None))
    self.assertEqual(decorator.flow, decorator._tls.flow)

  def test_token_response_param(self):
    self.decorator._token_response_param = 'foobar'
    self.test_required()

  def test_decorator_from_client_secrets(self):
    decorator = OAuth2DecoratorFromClientSecrets(
        datafile('client_secrets.json'),
        scope=['foo_scope', 'bar_scope'])
    self._finish_setup(decorator, user_mock=UserMock)

    self.assertFalse(decorator._in_error)
    self.test_required()
    http = self.decorator.http()
    self.assertEqual('foo_access_token', http.request.credentials.access_token)

    # revoke_uri is not required
    self.assertEqual(self.decorator._revoke_uri,
                     'https://accounts.google.com/o/oauth2/revoke')
    self.assertEqual(self.decorator._revoke_uri,
                     self.decorator.credentials.revoke_uri)

  def test_decorator_from_client_secrets_kwargs(self):
    decorator = OAuth2DecoratorFromClientSecrets(
        datafile('client_secrets.json'),
        scope=['foo_scope', 'bar_scope'],
        approval_prompt='force')
    self.assertIn('approval_prompt', decorator._kwargs)

  def test_decorator_from_cached_client_secrets(self):
    cache_mock = CacheMock()
    load_and_cache('client_secrets.json', 'secret', cache_mock)
    decorator = OAuth2DecoratorFromClientSecrets(
        # filename, scope, message=None, cache=None
        'secret', '', cache=cache_mock)
    self.assertFalse(decorator._in_error)

  def test_decorator_from_client_secrets_not_logged_in_required(self):
    decorator = OAuth2DecoratorFromClientSecrets(
        datafile('client_secrets.json'),
        scope=['foo_scope', 'bar_scope'], message='NotLoggedInMessage')
    self._finish_setup(decorator, user_mock=UserNotLoggedInMock)

    self.assertFalse(decorator._in_error)

    # An initial request to an oauth_required decorated path should be a
    # redirect to login.
    response = self.app.get('/foo_path')
    self.assertTrue(response.status.startswith('302'))
    self.assertIn('Login', str(response))

  def test_decorator_from_client_secrets_not_logged_in_aware(self):
    decorator = OAuth2DecoratorFromClientSecrets(
        datafile('client_secrets.json'),
        scope=['foo_scope', 'bar_scope'], message='NotLoggedInMessage')
    self._finish_setup(decorator, user_mock=UserNotLoggedInMock)

    # An initial request to an oauth_aware decorated path should be a
    # redirect to login.
    response = self.app.get('/bar_path/2012/03')
    self.assertTrue(response.status.startswith('302'))
    self.assertIn('Login', str(response))

  def _assert_unfilled_client_secrets_rejected(self):
    """Constructing a decorator from unfilled_client_secrets.json must raise."""
    MESSAGE = 'File is missing'
    self.assertRaises(InvalidClientSecretsError,
                      OAuth2DecoratorFromClientSecrets,
                      datafile('unfilled_client_secrets.json'),
                      scope=['foo_scope', 'bar_scope'], message=MESSAGE)

  def test_decorator_from_unfilled_client_secrets_required(self):
    self._assert_unfilled_client_secrets_rejected()

  def test_decorator_from_unfilled_client_secrets_aware(self):
    self._assert_unfilled_client_secrets_rejected()
809
810
class DecoratorXsrfSecretTests(unittest.TestCase):
  """Tests for appengine.xsrf_secret_key and its db/NDB secret models."""

  def setUp(self):
    bed = testbed.Testbed()
    bed.activate()
    bed.init_datastore_v3_stub()
    bed.init_memcache_stub()
    self.testbed = bed

  def tearDown(self):
    self.testbed.deactivate()

  def _evict_cached_secret(self):
    """Remove the XSRF secret from memcache; the datastore copy remains."""
    memcache.delete(appengine.XSRF_MEMCACHE_ID,
                    namespace=appengine.OAUTH2CLIENT_NAMESPACE)

  def test_build_and_parse_state(self):
    # Despite its name, this checks the site XSRF secret is stable, and is
    # only regenerated once both the memcache and datastore copies are gone.
    first = appengine.xsrf_secret_key()

    # Repeated calls return the same secret.
    second = appengine.xsrf_secret_key()
    self.assertEqual(first, second)

    # Losing the memcache copy alone does not change the secret.
    self._evict_cached_secret()
    third = appengine.xsrf_secret_key()
    self.assertEqual(second, third)

    # Losing both memcache and the stored model yields a new secret.
    self._evict_cached_secret()
    stored = appengine.SiteXsrfSecretKey.get_or_insert('site')
    stored.delete()

    fourth = appengine.xsrf_secret_key()
    self.assertNotEqual(third, fourth)

  def test_ndb_insert_db_get(self):
    fresh_secret = appengine._generate_new_xsrf_secret_key()
    appengine.SiteXsrfSecretKeyNDB(id='site', secret=fresh_secret).put()

    from_db = appengine.SiteXsrfSecretKey.get_by_key_name('site')
    self.assertEqual(from_db.secret, fresh_secret)

  def test_db_insert_ndb_get(self):
    fresh_secret = appengine._generate_new_xsrf_secret_key()
    appengine.SiteXsrfSecretKey(key_name='site', secret=fresh_secret).put()

    from_ndb = appengine.SiteXsrfSecretKeyNDB.get_by_id('site')
    self.assertEqual(from_ndb.secret, fresh_secret)
858
859
class DecoratorXsrfProtectionTests(unittest.TestCase):
  """Tests for appengine._build_state_value and _parse_state_value."""

  def setUp(self):
    bed = testbed.Testbed()
    bed.activate()
    bed.init_datastore_v3_stub()
    bed.init_memcache_stub()
    self.testbed = bed

  def tearDown(self):
    self.testbed.deactivate()

  def test_build_and_parse_state(self):
    handler = MockRequestHandler()
    state = appengine._build_state_value(handler, UserMock())

    # A well-formed state parses back to the originating request URL.
    parsed_url = appengine._parse_state_value(state, UserMock())
    self.assertEqual('https://example.org', parsed_url)

    # Any alteration (here: dropping the first character) must fail the
    # XSRF check.
    tampered = state[1:]
    self.assertRaises(appengine.InvalidXsrfTokenError,
                      appengine._parse_state_value, tampered, UserMock())
879
880
if __name__ == '__main__':
  # Run every TestCase defined in this module when executed as a script.
  unittest.main()
883