sources/oauth: fix OAuth1 not working, cleanup

This commit is contained in:
Jens Langhammer 2020-09-26 01:26:06 +02:00
parent d9c2b32cba
commit 7d533889bc
6 changed files with 51 additions and 57 deletions

View file

@ -1,5 +1,5 @@
"""OAuth Clients""" """OAuth Clients"""
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http import HttpRequest from django.http import HttpRequest
@ -18,18 +18,16 @@ class BaseOAuthClient:
"""Base OAuth Client""" """Base OAuth Client"""
session: Session session: Session
source: OAuthSource source: OAuthSource
token: str
request: HttpRequest request: HttpRequest
callback: Optional[str] callback: Optional[str]
def __init__( def __init__(
self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None
): ):
self.source = source self.source = source
self.token = ""
self.session = Session() self.session = Session()
self.request = request self.request = request
self.callback = callback self.callback = callback
@ -42,12 +40,7 @@ class BaseOAuthClient:
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]: def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
"Fetch user profile information." "Fetch user profile information."
try: try:
headers = { response = self.do_request("get", self.source.profile_url, token=token,)
"Authorization": f"{token['token_type']} {token['access_token']}"
}
response = self.session.request(
"get", self.source.profile_url, headers=headers,
)
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc) LOGGER.warning("Unable to fetch user profile", exc=exc)
@ -68,7 +61,7 @@ class BaseOAuthClient:
LOGGER.info("redirect args", **args) LOGGER.info("redirect args", **args)
return f"{self.source.authorization_url}?{params}" return f"{self.source.authorization_url}?{params}"
def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]: def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover

View file

@ -1,10 +1,7 @@
"""OAuth Clients""" """OAuth 1 Clients"""
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional
from urllib.parse import parse_qs from urllib.parse import parse_qsl
from django.db.models.expressions import Value
from django.http import HttpRequest
from django.utils.encoding import force_str
from requests.exceptions import RequestException from requests.exceptions import RequestException
from requests.models import Response from requests.models import Response
from requests_oauthlib import OAuth1 from requests_oauthlib import OAuth1
@ -32,13 +29,14 @@ class OAuthClient(BaseOAuthClient):
data = { data = {
"oauth_verifier": verifier, "oauth_verifier": verifier,
"oauth_callback": self.callback, "oauth_callback": self.callback,
"token": raw_token,
} }
token = self.parse_raw_token(raw_token)
try: try:
response = self.session.request( response = self.do_request(
"post", "post",
self.source.access_token_url, self.source.access_token_url,
data=data, data=data,
token=token,
headers=self._default_headers, headers=self._default_headers,
) )
response.raise_for_status() response.raise_for_status()
@ -46,20 +44,17 @@ class OAuthClient(BaseOAuthClient):
LOGGER.warning("Unable to fetch access token", exc=exc) LOGGER.warning("Unable to fetch access token", exc=exc)
return None return None
else: else:
return response.json() return self.parse_raw_token(response.text)
return None return None
def get_request_token(self) -> str: def get_request_token(self) -> str:
"Fetch the OAuth request token. Only required for OAuth 1.0." "Fetch the OAuth request token. Only required for OAuth 1.0."
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
try: try:
response = self.session.request( response = self.do_request(
"post", "post",
self.source.request_token_url, self.source.request_token_url,
data={ data={"oauth_callback": callback},
"oauth_callback": callback,
"oauth_consumer_key": self.source.consumer_key,
},
headers=self._default_headers, headers=self._default_headers,
) )
response.raise_for_status() response.raise_for_status()
@ -72,29 +67,34 @@ class OAuthClient(BaseOAuthClient):
"Get request parameters for redirect url." "Get request parameters for redirect url."
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token() raw_token = self.get_request_token()
token, _ = self.parse_raw_token(raw_token) token = self.parse_raw_token(raw_token)
self.request.session[self.session_key] = raw_token self.request.session[self.session_key] = raw_token
return { return {
"oauth_token": token, "oauth_token": token["oauth_token"],
"oauth_callback": callback, "oauth_callback": callback,
} }
def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]: def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
query_string = parse_qs(raw_token) return dict(parse_qsl(raw_token))
token = query_string["oauth_token"][0] # token = query_string["oauth_token"]
secret = query_string["oauth_token_secret"][0] # secret = query_string["oauth_token_secret"]
return (token, secret) # return (token, secret)
def do_request(self, method: str, url: str, **kwargs) -> Response: def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth." "Build remote url request. Constructs necessary auth."
user_token = kwargs.pop("token", self.token) resource_owner_key = None
token, secret = self.parse_raw_token(user_token) resource_owner_secret = None
if "token" in kwargs:
user_token: Dict[str, Any] = kwargs.pop("token")
resource_owner_key = user_token["oauth_token"]
resource_owner_secret = user_token["oauth_token_secret"]
callback = kwargs.pop("oauth_callback", None) callback = kwargs.pop("oauth_callback", None)
verifier = kwargs.get("data", {}).pop("oauth_verifier", None) verifier = kwargs.get("data", {}).pop("oauth_verifier", None)
oauth = OAuth1( oauth = OAuth1(
resource_owner_key=token, resource_owner_key=resource_owner_key,
resource_owner_secret=secret, resource_owner_secret=resource_owner_secret,
client_key=self.source.consumer_key, client_key=self.source.consumer_key,
client_secret=self.source.consumer_secret, client_secret=self.source.consumer_secret,
verifier=verifier, verifier=verifier,

View file

@ -1,9 +1,8 @@
"""OAuth Clients""" """OAuth 2 Clients"""
import json from json import loads
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from typing import Any, Dict, Optional
from urllib.parse import parse_qs from urllib.parse import parse_qsl
from django.http import HttpRequest
from django.utils.crypto import constant_time_compare, get_random_string from django.utils.crypto import constant_time_compare, get_random_string
from requests.exceptions import RequestException from requests.exceptions import RequestException
from requests.models import Response from requests.models import Response
@ -13,8 +12,6 @@ from passbook import __version__
from passbook.sources.oauth.clients.base import BaseOAuthClient from passbook.sources.oauth.clients.base import BaseOAuthClient
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING:
from passbook.sources.oauth.models import OAuthSource
class OAuth2Client(BaseOAuthClient): class OAuth2Client(BaseOAuthClient):
@ -88,25 +85,27 @@ class OAuth2Client(BaseOAuthClient):
self.request.session[self.session_key] = state self.request.session[self.session_key] = state
return args return args
def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]: def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
# Load as json first then parse as query string # Load as json first then parse as query string
try: try:
token_data = json.loads(raw_token) token_data = loads(raw_token)
except ValueError: except ValueError:
token = parse_qs(raw_token)["access_token"][0] return dict(parse_qsl(raw_token))
else: else:
token = token_data["access_token"] return token_data
return (token, None)
def do_request(self, method: str, url: str, **kwargs) -> Response: def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth." "Build remote url request. Constructs necessary auth."
user_token = kwargs.pop("token", self.token) if "token" in kwargs:
token, _ = self.parse_raw_token(user_token) token = self.parse_raw_token(kwargs.pop("token"))
if token is not None:
params = kwargs.get("params", {}) params = kwargs.get("params", {})
params["access_token"] = token params["access_token"] = token["access_token"]
kwargs["params"] = params kwargs["params"] = params
headers = kwargs.get("headers", {})
headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
return super().do_request(method, url, **kwargs) return super().do_request(method, url, **kwargs)
@property @property

View file

@ -1,3 +1,4 @@
"""OAuth Source Exception"""
from passbook.lib.sentry import SentryIgnoredException from passbook.lib.sentry import SentryIgnoredException

View file

@ -24,7 +24,7 @@ class RedditOAuthRedirect(OAuthRedirect):
class RedditOAuth2Client(OAuth2Client): class RedditOAuth2Client(OAuth2Client):
"""Reddit OAuth2 Client""" """Reddit OAuth2 Client"""
def get_access_token(self, request, callback=None, **request_kwargs): def get_access_token(self, **request_kwargs):
"Fetch access token from callback request." "Fetch access token from callback request."
auth = HTTPBasicAuth(self.source.consumer_key, self.source.consumer_secret) auth = HTTPBasicAuth(self.source.consumer_key, self.source.consumer_secret)
return super().get_access_token(auth=auth) return super().get_access_token(auth=auth)

View file

@ -51,10 +51,11 @@ class OAuthCallback(OAuthClientMixin, View):
if not self.source.enabled: if not self.source.enabled:
raise Http404(f"Source {slug} is not enabled.") raise Http404(f"Source {slug} is not enabled.")
client = self.get_client(self.source) client = self.get_client(
callback = self.get_callback_url(self.source) self.source, callback=self.get_callback_url(self.source)
)
# Fetch access token # Fetch access token
token = client.get_access_token(callback=callback) token = client.get_access_token()
if token is None: if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.") return self.handle_login_failure(self.source, "Could not retrieve token.")
if "error" in token: if "error" in token: