Skip to content
This repository was archived by the owner on Mar 3, 2020. It is now read-only.

Commit e709d71

Browse files
committed
Merge pull request #25 from edx/clintonb/client-credentials
Added support for client_credentials grant type
2 parents d672deb + a636a72 commit e709d71

File tree

5 files changed

+162
-29
lines changed

5 files changed

+162
-29
lines changed

provider/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.4.0'
1+
__version__ = '0.5.0'

provider/oauth2/forms.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from django.utils.encoding import smart_unicode
44
from django.utils.translation import ugettext as _
55

6-
from provider import scope
6+
from provider import constants, scope
77
from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES
88
from provider.forms import OAuthForm, OAuthValidationError
99
from provider.oauth2.models import Client, Grant, RefreshToken
@@ -336,3 +336,14 @@ def clean(self):
336336

337337
data['client'] = client
338338
return data
339+
340+
341+
class ClientCredentialsGrantForm(ScopeMixin, OAuthForm):
342+
""" Validate a client credentials grant request. """
343+
344+
def clean(self):
345+
cleaned_data = super(ClientCredentialsGrantForm, self).clean()
346+
# We do not fully support scopes for this grant type; however, a scope is required
347+
# in order to create an access token. Default to read-only access.
348+
cleaned_data['scope'] = constants.READ
349+
return cleaned_data

provider/oauth2/tests.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,71 @@ def test_access_token_response_valid_token_type(self):
588588
self.assertEqual(token['token_type'], constants.TOKEN_TYPE, token)
589589

590590

591+
@ddt.ddt
592+
class ClientCredentialsAccessTokenTests(BaseOAuth2TestCase):
593+
""" Tests for issuing access tokens using the client credentials grant. """
594+
fixtures = ['test_oauth2.json']
595+
596+
def setUp(self):
597+
super(ClientCredentialsAccessTokenTests, self).setUp()
598+
AccessToken.objects.all().delete()
599+
600+
def request_access_token(self, client_id=None, client_secret=None):
601+
""" Issues an access token request using the client credentials grant.
602+
603+
Arguments:
604+
client_id (str): Optional override of the client ID credential.
605+
client_secret (str): Optional override of the client secret credential.
606+
607+
Returns:
608+
HttpResponse
609+
"""
610+
client = self.get_client()
611+
data = {
612+
'grant_type': 'client_credentials',
613+
'client_id': client_id or client.client_id,
614+
'client_secret': client_secret or client.client_secret,
615+
}
616+
617+
return self.client.post(self.access_token_url(), data)
618+
619+
def assert_valid_access_token_response(self, access_token, response):
620+
""" Verifies the content of the response contains a JSON representation of the access token.
621+
622+
Note:
623+
The access token should NOT have an associated refresh token.
624+
"""
625+
expected = {
626+
u'access_token': access_token.token,
627+
u'token_type': constants.TOKEN_TYPE,
628+
u'expires_in': access_token.get_expire_delta(),
629+
u'scope': u' '.join(scope.names(access_token.scope)),
630+
}
631+
632+
self.assertEqual(json.loads(response.content), expected)
633+
634+
def get_latest_access_token(self):
635+
return AccessToken.objects.filter(client=self.get_client()).order_by('-id')[0]
636+
637+
def test_authorize_success(self):
638+
""" Verify the endpoint successfully issues an access token using the client credentials grant. """
639+
response = self.request_access_token()
640+
self.assertEqual(200, response.status_code, response.content)
641+
642+
access_token = self.get_latest_access_token()
643+
self.assert_valid_access_token_response(access_token, response)
644+
645+
@ddt.data(
646+
{'client_id': 'invalid'},
647+
{'client_secret': 'invalid'},
648+
)
649+
def test_authorize_with_invalid_credentials(self, credentials_override):
650+
""" Verify the endpoint returns HTTP 400 if the credentials are invalid. """
651+
response = self.request_access_token(**credentials_override)
652+
self.assertEqual(400, response.status_code, response.content)
653+
self.assertDictEqual(json.loads(response.content), {'error': 'invalid_client'})
654+
655+
591656
class AuthBackendTest(BaseOAuth2TestCase):
592657
fixtures = ['test_oauth2']
593658

provider/oauth2/views.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from provider import constants
1010
from provider.oauth2.backends import BasicClientBackend, RequestParamsClientBackend, PublicPasswordBackend
1111
from provider.oauth2.forms import (AuthorizationCodeGrantForm, AuthorizationRequestForm, AuthorizationForm,
12-
PasswordGrantForm, RefreshTokenGrantForm)
12+
PasswordGrantForm, RefreshTokenGrantForm, ClientCredentialsGrantForm)
1313
from provider.oauth2.models import Client, RefreshToken, AccessToken
1414
from provider.utils import now
1515
from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin, Capture, Authorize, Redirect
@@ -24,7 +24,6 @@ def get_access_token(self, request, user, scope, client):
2424
except AccessToken.DoesNotExist:
2525
# None found... make a new one!
2626
at = self.create_access_token(request, user, scope, client)
27-
self.create_refresh_token(request, user, scope, at, client)
2827
return at
2928

3029
def create_access_token(self, request, user, scope, client):
@@ -140,6 +139,12 @@ def get_password_grant(self, request, data, client):
140139
raise OAuthError(form.errors)
141140
return form.cleaned_data
142141

142+
def get_client_credentials_grant(self, request, data, client):
143+
form = ClientCredentialsGrantForm(data, client=client)
144+
if not form.is_valid():
145+
raise OAuthError(form.errors)
146+
return form.cleaned_data
147+
143148
def invalidate_grant(self, grant):
144149
if constants.DELETE_EXPIRED:
145150
grant.delete()

provider/views.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,25 @@ def access_token_response_data(self, access_token, response_type=None):
105105

106106
return response_data
107107

108+
def get_access_and_refresh_tokens(self, request, user, scope, client, reuse_existing_access_token=False, create_refresh_token=True):
109+
"""
110+
Returns an AccessToken and RefreshToken for the given user, scope, and client combination.
111+
112+
Returns:
113+
(AccessToken, RefreshToken)
114+
If create_refresh_token is False, the second element of the tuple will be None.
115+
"""
116+
if reuse_existing_access_token:
117+
at = self.get_access_token(request, user, scope, client)
118+
else:
119+
at = self.create_access_token(request, user, scope, client)
120+
121+
rt = None
122+
if create_refresh_token and not reuse_existing_access_token:
123+
rt = self.create_refresh_token(request, user, scope, at, client)
124+
125+
return at, rt
126+
108127

109128
class OAuthView(TemplateView):
110129
"""
@@ -333,15 +352,14 @@ def get_implicit_response(self, request, client):
333352
data = self.get_data(request)
334353

335354
lookup_kwargs = {
336-
"user": request.user,
337-
"client": client,
338-
"scope": scope.to_int(*data.get('scope', constants.SCOPES[0][1]).split())
355+
'user': request.user,
356+
'client': client,
357+
'scope': scope.to_int(*data.get('scope', constants.SCOPES[0][1]).split()),
358+
'reuse_existing_access_token': constants.SINGLE_ACCESS_TOKEN,
359+
'create_refresh_token': False
339360
}
340361

341-
if constants.SINGLE_ACCESS_TOKEN:
342-
token = self.get_access_token(request, **lookup_kwargs)
343-
else:
344-
token = self.create_access_token(request, **lookup_kwargs)
362+
token, __ = self.get_access_and_refresh_tokens(request, **lookup_kwargs)
345363

346364
response_data = self.access_token_response_data(token, data['response_type'])
347365

@@ -503,7 +521,7 @@ class AccessToken(OAuthView, Mixin, AccessTokenMixin):
503521
Authentication backends used to authenticate a particular client.
504522
"""
505523

506-
grant_types = ['authorization_code', 'refresh_token', 'password']
524+
grant_types = ['authorization_code', 'refresh_token', 'password', 'client_credentials']
507525
"""
508526
The default grant types supported by this view.
509527
"""
@@ -532,6 +550,14 @@ def get_password_grant(self, request, data, client):
532550
"""
533551
raise NotImplementedError # pragma: no cover
534552

553+
def get_client_credentials_grant(self, request, data, client):
554+
"""
555+
Return the optional parameters (scope) associated with this request.
556+
557+
:return: ``tuple`` - ``(True or False, options)``
558+
"""
559+
raise NotImplementedError # pragma: no cover
560+
535561
def invalidate_grant(self, grant):
536562
"""
537563
Override to handle grant invalidation. A grant is invalidated right
@@ -564,13 +590,16 @@ def authorization_code(self, request, data, client):
564590
Handle ``grant_type=authorization_code`` requests as defined in
565591
:rfc:`4.1.3`.
566592
"""
567-
grant = self.get_authorization_code_grant(request, request.POST,
568-
client)
569-
if constants.SINGLE_ACCESS_TOKEN:
570-
at = self.get_access_token(request, grant.user, grant.scope, client)
571-
else:
572-
at = self.create_access_token(request, grant.user, grant.scope, client)
573-
rt = self.create_refresh_token(request, grant.user, grant.scope, at, client)
593+
grant = self.get_authorization_code_grant(request, request.POST, client)
594+
595+
kwargs = {
596+
'request': request,
597+
'user': grant.user,
598+
'scope': grant.scope,
599+
'client': client,
600+
'reuse_existing_access_token': constants.SINGLE_ACCESS_TOKEN,
601+
}
602+
at, rt = self.get_access_and_refresh_tokens(**kwargs)
574603

575604
self.invalidate_grant(grant)
576605

@@ -586,8 +615,13 @@ def refresh_token(self, request, data, client):
586615
self.invalidate_refresh_token(rt)
587616
self.invalidate_access_token(rt.access_token)
588617

589-
at = self.create_access_token(request, rt.user, rt.access_token.scope, client)
590-
rt = self.create_refresh_token(request, at.user, at.scope, at, client)
618+
kwargs = {
619+
'request': request,
620+
'user': rt.user,
621+
'scope': rt.access_token.scope,
622+
'client': client,
623+
}
624+
at, rt = self.get_access_and_refresh_tokens(**kwargs)
591625

592626
return self.access_token_response(at)
593627

@@ -597,16 +631,32 @@ def password(self, request, data, client):
597631
"""
598632

599633
data = self.get_password_grant(request, data, client)
600-
user = data.get('user')
601-
scope = data.get('scope')
634+
kwargs = {
635+
'request': request,
636+
'user': data.get('user'),
637+
'scope': data.get('scope'),
638+
'client': client,
639+
'reuse_existing_access_token': constants.SINGLE_ACCESS_TOKEN,
602640

603-
if constants.SINGLE_ACCESS_TOKEN:
604-
at = self.get_access_token(request, user, scope, client)
605-
else:
606-
at = self.create_access_token(request, user, scope, client)
607641
# Public clients don't get refresh tokens
608-
if client.client_type == constants.CONFIDENTIAL:
609-
rt = self.create_refresh_token(request, user, scope, at, client)
642+
'create_refresh_token': client.client_type == constants.CONFIDENTIAL
643+
}
644+
at, rt = self.get_access_and_refresh_tokens(**kwargs)
645+
646+
return self.access_token_response(at)
647+
648+
def client_credentials(self, request, data, client):
649+
""" Handle ``grant_type=client_credentials`` requests as defined in :rfc:`4.4`. """
650+
data = self.get_client_credentials_grant(request, data, client)
651+
kwargs = {
652+
'request': request,
653+
'user': client.user,
654+
'scope': data.get('scope'),
655+
'client': client,
656+
'reuse_existing_access_token': constants.SINGLE_ACCESS_TOKEN,
657+
'create_refresh_token': False,
658+
}
659+
at, rt = self.get_access_and_refresh_tokens(**kwargs)
610660

611661
return self.access_token_response(at)
612662

@@ -622,6 +672,8 @@ def get_handler(self, grant_type):
622672
return self.refresh_token
623673
elif grant_type == 'password':
624674
return self.password
675+
elif grant_type == 'client_credentials':
676+
return self.client_credentials
625677
return None
626678

627679
def get(self, request):

0 commit comments

Comments
 (0)