Skip to content
This repository was archived by the owner on Mar 3, 2020. It is now read-only.

Commit 0382e3f

Browse files
author
Jim Abramson
committed
Merge pull request #22 from edx/jsa/implicit2
Refactor implicit flow.
2 parents 87f872c + da1cb1c commit 0382e3f

File tree

6 files changed

+282
-204
lines changed

6 files changed

+282
-204
lines changed

provider/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.3.0'
1+
__version__ = '0.4.0'

provider/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
WRITE = 1 << 2
1919
READ_WRITE = READ | WRITE
2020

21+
# NOTE that DEFAULT_SCOPES[0] (i.e. READ / 'read') is the default OAuth2 scope, per section 3.3 of rfc6749.
2122
DEFAULT_SCOPES = (
2223
(READ, 'read'),
2324
(WRITE, 'write'),

provider/oauth2/tests.py

Lines changed: 95 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import urlparse
44

5+
import ddt
56
from django.conf import settings
67
from django.contrib.auth.models import User
78
from django.core.urlresolvers import reverse
@@ -64,6 +65,7 @@ def _login_and_authorize(self, url_func=None):
6465
self.assertTrue(self.redirect_url() in response['Location'])
6566

6667

68+
@ddt.ddt
6769
class AuthorizationTest(BaseOAuth2TestCase):
6870
fixtures = ['test_oauth2']
6971

@@ -89,17 +91,40 @@ def test_authorization_requires_login(self):
8991

9092
self.assertTrue(self.auth_url2() in response['Location'])
9193

92-
def test_authorization_requires_client_id(self):
94+
@ddt.data(
95+
('read', 'read'),
96+
('write', 'write'),
97+
('read+write', 'read write read+write'),
98+
)
99+
@ddt.unpack
100+
def test_implicit_flow(self, requested_scope, expected_scope):
101+
"""
102+
End-to-end test of the implicit flow (happy path).
103+
"""
93104
self.login()
94-
response = self.client.get(self.auth_url())
105+
self.client.get(self.auth_url(), data=self.get_auth_params(response_type='token', scope=requested_scope))
106+
response = self.client.post(self.auth_url2(), {'authorize': True})
107+
fragment = urlparse.urlparse(response['Location']).fragment
108+
auth_response_data = {k: v[0] for k, v in urlparse.parse_qs(fragment).items()}
109+
self.assertEqual(auth_response_data['scope'], expected_scope)
110+
self.assertEqual(auth_response_data['access_token'], AccessToken.objects.all()[0].token)
111+
self.assertEqual(auth_response_data['token_type'], 'Bearer')
112+
self.assertEqual(int(auth_response_data['expires_in']), constants.EXPIRE_DELTA.days * 60 * 60 * 24 - 1)
113+
self.assertNotIn('refresh_token', response)
114+
115+
@ddt.data('code', 'token')
116+
def test_authorization_requires_client_id(self, response_type):
117+
self.login()
118+
self.client.get(self.auth_url(), data={'response_type': response_type})
95119
response = self.client.get(self.auth_url2())
96120

97121
self.assertEqual(400, response.status_code)
98122
self.assertTrue("An unauthorized client tried to access your resources." in response.content)
99123

100-
def test_authorization_rejects_invalid_client_id(self):
124+
@ddt.data('code', 'token')
125+
def test_authorization_rejects_invalid_client_id(self, response_type):
101126
self.login()
102-
response = self.client.get(self.auth_url(), data={"client_id": 123})
127+
response = self.client.get(self.auth_url(), data={"client_id": 123, 'response_type': response_type})
103128
response = self.client.get(self.auth_url2())
104129

105130
self.assertEqual(400, response.status_code)
@@ -113,22 +138,19 @@ def test_authorization_requires_response_type(self):
113138
self.assertEqual(400, response.status_code)
114139
self.assertTrue(escape(u"No 'response_type' supplied.") in response.content)
115140

116-
def test_authorization_requires_supported_response_type(self):
141+
@ddt.data('code', 'token', 'unsupported')
142+
def test_authorization_requires_supported_response_type(self, response_type):
117143
self.login()
118144
response = self.client.get(
119-
self.auth_url(), self.get_auth_params(response_type="unsupported"))
145+
self.auth_url(), self.get_auth_params(response_type=response_type))
120146
response = self.client.get(self.auth_url2())
121147

122-
self.assertEqual(400, response.status_code)
123-
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content)
148+
if response_type == 'unsupported':
149+
self.assertEqual(400, response.status_code)
150+
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content)
124151

125-
response = self.client.get(self.auth_url(), data=self.get_auth_params())
126-
response = self.client.get(self.auth_url2())
127-
self.assertEqual(200, response.status_code, response.content)
128-
129-
response = self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token"))
130-
response = self.client.get(self.auth_url2())
131-
self.assertEqual(200, response.status_code)
152+
else:
153+
self.assertEqual(200, response.status_code)
132154

133155
def test_token_authorization_redirects_to_correct_uri(self):
134156
self.login()
@@ -212,48 +234,83 @@ def test_token_authorization_cancellation(self):
212234

213235
self.assertEqual(AccessToken.objects.count(), 0)
214236

215-
def test_authorization_requires_a_valid_redirect_uri(self):
237+
@ddt.data('code', 'token')
238+
def test_authorization_requires_a_valid_redirect_uri(self, response_type):
216239
self.login()
217240

218-
response = self.client.get(self.auth_url(),
219-
data=self.get_auth_params(redirect_uri=self.get_client().redirect_uri + '-invalid'))
241+
self.client.get(
242+
self.auth_url(),
243+
data=self.get_auth_params(
244+
response_type=response_type, redirect_uri=self.get_client().redirect_uri + '-invalid'
245+
)
246+
)
220247
response = self.client.get(self.auth_url2())
221248

222249
self.assertEqual(400, response.status_code)
223250
self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content)
224251

225-
response = self.client.get(self.auth_url(),
226-
data=self.get_auth_params(redirect_uri=self.get_client().redirect_uri))
252+
self.client.get(self.auth_url(), data=self.get_auth_params(
253+
response_type=response_type, redirect_uri=self.get_client().redirect_uri))
227254
response = self.client.get(self.auth_url2())
228255

229256
self.assertEqual(200, response.status_code)
230257

231-
def test_authorization_requires_a_valid_scope(self):
258+
@ddt.data('code', 'token')
259+
def test_authorization_requires_a_valid_scope(self, response_type):
232260
self.login()
233261

234-
response = self.client.get(self.auth_url(), data=self.get_auth_params(scope="invalid"))
262+
self.client.get(self.auth_url(), data=self.get_auth_params(response_type=response_type, scope="invalid"))
235263
response = self.client.get(self.auth_url2())
236264

237265
self.assertEqual(400, response.status_code)
238266
self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content,
239267
'Expected `{0}` in {1}'.format(escape(u"'invalid' is not a valid scope."), response.content))
240268

241-
response = self.client.get(self.auth_url(), data=self.get_auth_params(scope=constants.SCOPES[0][1]))
269+
self.client.get(
270+
self.auth_url(),
271+
data=self.get_auth_params(response_type=response_type, scope=constants.SCOPES[0][1])
272+
)
242273
response = self.client.get(self.auth_url2())
243274
self.assertEqual(200, response.status_code)
244275

245-
def test_authorization_is_not_granted(self):
276+
@ddt.data('code', 'token')
277+
def test_authorization_sets_default_scope(self, response_type):
278+
279+
self.login()
280+
self.client.get(self.auth_url(), data=self.get_auth_params(response_type=response_type))
281+
response = self.client.post(self.auth_url2(), {'authorize': True})
282+
283+
if response_type == 'code':
284+
# authorization code flow
285+
response = self.client.get(self.redirect_url())
286+
query = urlparse.urlparse(response['Location']).query
287+
code = urlparse.parse_qs(query)['code'][0]
288+
response = self.client.post(self.access_token_url(), {
289+
'grant_type': 'authorization_code',
290+
'client_id': self.get_client().client_id,
291+
'client_secret': self.get_client().client_secret,
292+
'code': code})
293+
scope_str = json.loads(response.content).get('scope')
294+
else:
295+
# implicit flow
296+
fragment = urlparse.urlparse(response['Location']).fragment
297+
scope_str = urlparse.parse_qs(fragment)['scope'][0]
298+
299+
self.assertEqual(scope_str, constants.SCOPES[0][1])
300+
301+
@ddt.data('code', 'token')
302+
def test_authorization_is_not_granted(self, response_type):
246303
self.login()
247304

248-
response = self.client.get(self.auth_url(), data=self.get_auth_params(response_type="code"))
249-
response = self.client.get(self.auth_url2())
305+
self.client.get(self.auth_url(), data=self.get_auth_params(response_type=response_type))
306+
self.client.get(self.auth_url2())
250307

251308
response = self.client.post(self.auth_url2(), {'authorize': False, 'scope': constants.SCOPES[0][1]})
252309
self.assertEqual(302, response.status_code, response.content)
253310
self.assertTrue(self.get_client().redirect_uri in response['Location'],
254311
'{0} not in {1}'.format(self.redirect_url(), response['Location']))
255312
self.assertTrue('error=access_denied' in response['Location'])
256-
self.assertFalse('code' in response['Location'])
313+
self.assertFalse(response_type in response['Location'])
257314

258315
def test_authorization_is_granted(self):
259316
self.login()
@@ -278,6 +335,17 @@ def test_preserving_the_state_variable(self):
278335
self.assertTrue('code' in response['Location'])
279336
self.assertTrue('state=abc' in response['Location'])
280337

338+
def test_preserving_the_state_variable_implicit(self):
339+
self.login()
340+
341+
self.client.get(self.auth_url(), data=self.get_auth_params(response_type='token', state='abc'))
342+
self.client.get(self.auth_url2())
343+
response = self.client.post(self.auth_url2(), {'authorize': True, 'scope': constants.SCOPES[0][1]})
344+
self.assertEqual(302, response.status_code)
345+
self.assertFalse('error' in response['Location'])
346+
self.assertTrue('access_token=' in response['Location'])
347+
self.assertTrue('state=abc' in response['Location'])
348+
281349
def test_redirect_requires_valid_data(self):
282350
self.login()
283351
response = self.client.get(self.redirect_url())

provider/oauth2/views.py

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,53 @@
1313
from provider.oauth2.forms import PasswordGrantForm, RefreshTokenGrantForm
1414
from provider.oauth2.models import Client, RefreshToken, AccessToken
1515
from provider.utils import now
16-
from provider.views import AccessToken as AccessTokenView, OAuthError
16+
from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin
1717
from provider.views import Capture, Authorize, Redirect
1818

1919

20+
class OAuth2AccessTokenMixin(AccessTokenMixin):
21+
22+
def get_access_token(self, request, user, scope, client):
23+
try:
24+
# Attempt to fetch an existing access token.
25+
at = AccessToken.objects.get(user=user, client=client,
26+
scope=scope, expires__gt=now())
27+
except AccessToken.DoesNotExist:
28+
# None found... make a new one!
29+
at = self.create_access_token(request, user, scope, client)
30+
self.create_refresh_token(request, user, scope, at, client)
31+
return at
32+
33+
def create_access_token(self, request, user, scope, client):
34+
return AccessToken.objects.create(
35+
user=user,
36+
client=client,
37+
scope=scope
38+
)
39+
40+
def create_refresh_token(self, request, user, scope, access_token, client):
41+
return RefreshToken.objects.create(
42+
user=user,
43+
access_token=access_token,
44+
client=client
45+
)
46+
47+
def invalidate_refresh_token(self, rt):
48+
if constants.DELETE_EXPIRED:
49+
rt.delete()
50+
else:
51+
rt.expired = True
52+
rt.save()
53+
54+
def invalidate_access_token(self, at):
55+
if constants.DELETE_EXPIRED:
56+
at.delete()
57+
else:
58+
at.expires = now() - timedelta(milliseconds=1)
59+
at.save()
60+
61+
62+
2063
class Capture(Capture):
2164
"""
2265
Implementation of :class:`provider.views.Capture`.
@@ -26,7 +69,7 @@ def get_redirect_url(self, request):
2669
return reverse('oauth2:authorize')
2770

2871

29-
class Authorize(Authorize):
72+
class Authorize(Authorize, OAuth2AccessTokenMixin):
3073
"""
3174
Implementation of :class:`provider.views.Authorize`.
3275
"""
@@ -67,7 +110,7 @@ class Redirect(Redirect):
67110
pass
68111

69112

70-
class AccessTokenView(AccessTokenView):
113+
class AccessTokenView(AccessTokenView, OAuth2AccessTokenMixin):
71114
"""
72115
Implementation of :class:`provider.views.AccessToken`.
73116
@@ -100,52 +143,13 @@ def get_password_grant(self, request, data, client):
100143
raise OAuthError(form.errors)
101144
return form.cleaned_data
102145

103-
def get_access_token(self, request, user, scope, client):
104-
try:
105-
# Attempt to fetch an existing access token.
106-
at = AccessToken.objects.get(user=user, client=client,
107-
scope=scope, expires__gt=now())
108-
except AccessToken.DoesNotExist:
109-
# None found... make a new one!
110-
at = self.create_access_token(request, user, scope, client)
111-
self.create_refresh_token(request, user, scope, at, client)
112-
return at
113-
114-
def create_access_token(self, request, user, scope, client):
115-
return AccessToken.objects.create(
116-
user=user,
117-
client=client,
118-
scope=scope
119-
)
120-
121-
def create_refresh_token(self, request, user, scope, access_token, client):
122-
return RefreshToken.objects.create(
123-
user=user,
124-
access_token=access_token,
125-
client=client
126-
)
127-
128146
def invalidate_grant(self, grant):
129147
if constants.DELETE_EXPIRED:
130148
grant.delete()
131149
else:
132150
grant.expires = now() - timedelta(days=1)
133151
grant.save()
134152

135-
def invalidate_refresh_token(self, rt):
136-
if constants.DELETE_EXPIRED:
137-
rt.delete()
138-
else:
139-
rt.expired = True
140-
rt.save()
141-
142-
def invalidate_access_token(self, at):
143-
if constants.DELETE_EXPIRED:
144-
at.delete()
145-
else:
146-
at.expires = now() - timedelta(days=1)
147-
at.save()
148-
149153

150154
class AccessTokenDetailView(View):
151155
"""

0 commit comments

Comments
 (0)