diff --git a/passbook/oauth_client/source_types/discord.py b/passbook/oauth_client/source_types/discord.py index 6052e519e..4a0daad17 100644 --- a/passbook/oauth_client/source_types/discord.py +++ b/passbook/oauth_client/source_types/discord.py @@ -2,7 +2,6 @@ import json from logging import getLogger -from django.contrib.auth import get_user_model from requests.exceptions import RequestException from passbook.oauth_client.clients import OAuth2Client @@ -50,12 +49,11 @@ class DiscordOAuth2Callback(OAuthCallback): client_class = DiscordOAuth2Client def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('username'), + 'username': info.get('username'), 'email': info.get('email', 'None'), 'first_name': info.get('username'), 'password': None, } - discord_user = user_get_or_create(user_model=user, **user_data) + discord_user = user_get_or_create(**user_data) return discord_user diff --git a/passbook/oauth_client/source_types/facebook.py b/passbook/oauth_client/source_types/facebook.py index 143d32bc4..57a3dc277 100644 --- a/passbook/oauth_client/source_types/facebook.py +++ b/passbook/oauth_client/source_types/facebook.py @@ -1,7 +1,5 @@ """Facebook OAuth Views""" -from django.contrib.auth import get_user_model - from passbook.oauth_client.source_types.manager import MANAGER, RequestKind from passbook.oauth_client.utils import user_get_or_create from passbook.oauth_client.views.core import OAuthCallback, OAuthRedirect @@ -22,12 +20,11 @@ class FacebookOAuth2Callback(OAuthCallback): """Facebook OAuth2 Callback""" def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('name'), + 'username': info.get('name'), 'email': info.get('email', ''), 'first_name': info.get('name'), 'password': None, } - fb_user = user_get_or_create(user_model=user, **user_data) + fb_user = user_get_or_create(**user_data) return fb_user diff --git a/passbook/oauth_client/source_types/github.py b/passbook/oauth_client/source_types/github.py index 38c2bb288..fb75fe506 100644 --- a/passbook/oauth_client/source_types/github.py +++ b/passbook/oauth_client/source_types/github.py @@ -1,7 +1,5 @@ """GitHub OAuth Views""" -from django.contrib.auth import get_user_model - from passbook.oauth_client.source_types.manager import MANAGER, RequestKind from passbook.oauth_client.utils import user_get_or_create from passbook.oauth_client.views.core import OAuthCallback @@ -12,12 +10,11 @@ class GitHubOAuth2Callback(OAuthCallback): """GitHub OAuth2 Callback""" def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('login'), + 'username': info.get('login'), 'email': info.get('email', ''), 'first_name': info.get('name'), 'password': None, } - gh_user = user_get_or_create(user_model=user, **user_data) + gh_user = user_get_or_create(**user_data) return gh_user diff --git a/passbook/oauth_client/source_types/google.py b/passbook/oauth_client/source_types/google.py index c739c1156..e92731b2e 100644 --- a/passbook/oauth_client/source_types/google.py +++ b/passbook/oauth_client/source_types/google.py @@ -1,6 +1,4 @@ """Google OAuth Views""" -from django.contrib.auth import get_user_model - from passbook.oauth_client.source_types.manager import MANAGER, RequestKind from passbook.oauth_client.utils import user_get_or_create from passbook.oauth_client.views.core import OAuthCallback, OAuthRedirect @@ -21,12 +19,11 @@ class GoogleOAuth2Callback(OAuthCallback): """Google OAuth2 Callback""" def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('email'), + 'username': info.get('email'), 'email': info.get('email', ''), 'first_name': info.get('name'), 'password': None, } - google_user = user_get_or_create(user_model=user, **user_data) + google_user = user_get_or_create(**user_data) return google_user diff --git a/passbook/oauth_client/source_types/reddit.py b/passbook/oauth_client/source_types/reddit.py index dbdd8c375..87f175eb5 100644 --- a/passbook/oauth_client/source_types/reddit.py +++ b/passbook/oauth_client/source_types/reddit.py @@ -2,7 +2,6 @@ import json from logging import getLogger -from django.contrib.auth import get_user_model from requests.auth import HTTPBasicAuth from requests.exceptions import RequestException @@ -59,12 +58,11 @@ class RedditOAuth2Callback(OAuthCallback): client_class = RedditOAuth2Client def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('name'), + 'username': info.get('name'), 'email': None, 'first_name': info.get('name'), 'password': None, } - reddit_user = user_get_or_create(user_model=user, **user_data) + reddit_user = user_get_or_create(**user_data) return reddit_user diff --git a/passbook/oauth_client/source_types/supervisr.py b/passbook/oauth_client/source_types/supervisr.py index 9272e40f7..a0e702502 100644 --- a/passbook/oauth_client/source_types/supervisr.py +++ b/passbook/oauth_client/source_types/supervisr.py @@ -3,7 +3,6 @@ import json from logging import getLogger -from django.contrib.auth import get_user_model from requests.exceptions import RequestException from passbook.oauth_client.clients import OAuth2Client @@ -44,12 +43,11 @@ class SupervisrOAuthCallback(OAuthCallback): return info['pk'] def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('username'), + 'username': info.get('username'), 'email': info.get('email', ''), 'first_name': info.get('first_name'), 'password': None, } - sv_user = user_get_or_create(user_model=user, **user_data) + sv_user = user_get_or_create(**user_data) return sv_user diff --git a/passbook/oauth_client/source_types/twitter.py b/passbook/oauth_client/source_types/twitter.py index 60a7d8b2c..449268e93 100644 --- a/passbook/oauth_client/source_types/twitter.py +++ b/passbook/oauth_client/source_types/twitter.py @@ -2,7 +2,6 @@ from logging import getLogger -from django.contrib.auth import get_user_model from requests.exceptions import RequestException from passbook.oauth_client.clients import OAuthClient @@ -36,12 +35,11 @@ class TwitterOAuthCallback(OAuthCallback): client_class = TwitterOAuthClient def get_or_create_user(self, source, access, info): - user = get_user_model() user_data = { - user.USERNAME_FIELD: info.get('screen_name'), + 'username': info.get('screen_name'), 'email': info.get('email', ''), 'first_name': info.get('name'), 'password': None, } - tw_user = user_get_or_create(user_model=user, **user_data) + tw_user = user_get_or_create(**user_data) return tw_user diff --git a/passbook/oauth_client/utils.py b/passbook/oauth_client/utils.py index 91ab2211b..ed5bd3e0e 100644 --- a/passbook/oauth_client/utils.py +++ b/passbook/oauth_client/utils.py @@ -1,16 +1,17 @@ """OAuth Client User Creation Utils""" -from django.contrib.auth import get_user_model from django.db.utils import IntegrityError +from passbook.core.models import User -def user_get_or_create(user_model=None, **kwargs): + +def user_get_or_create(**kwargs): """Create user or return existing user""" - if user_model is None: - user_model = get_user_model() try: - new_user = user_model.objects.create_user(**kwargs) + new_user = User.objects.create_user(**kwargs) except IntegrityError: - # TODO: Fix potential username change vuln - new_user = user_model.objects.get(username=kwargs['username']) + # At this point we've already checked that there is no existing connection + # to any user. Hence if we can't create the user, + kwargs['username'] = '%s_1' % kwargs['username'] + new_user = User.objects.create_user(**kwargs) return new_user diff --git a/passbook/oauth_client/views/core.py b/passbook/oauth_client/views/core.py index fb5819b33..cb5af5102 100644 --- a/passbook/oauth_client/views/core.py +++ b/passbook/oauth_client/views/core.py @@ -113,7 +113,9 @@ class OAuthCallback(OAuthClientMixin, View): ) user = authenticate(source=self.source, identifier=identifier, request=request) if user is None: + LOGGER.debug("Handling new user") return self.handle_new_user(self.source, connection, info) + LOGGER.debug("Handling existing user") return self.handle_existing_user(self.source, user, connection, info) # pylint: disable=unused-argument