providers/oauth2: rewrite introspection endpoint to allow basic or bearer auth

This commit is contained in:
Jens Langhammer 2020-09-28 11:42:27 +02:00
parent 553f184aad
commit 8f4e954160
5 changed files with 74 additions and 66 deletions

View file

@ -4,14 +4,14 @@ from time import sleep
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from channels.testing import ChannelsLiveServerTestCase
from docker.client import DockerClient, from_env from docker.client import DockerClient, from_env
from docker.models.containers import Container from docker.models.containers import Container
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
from channels.testing import ChannelsLiveServerTestCase
from passbook import __version__
from e2e.utils import USER, SeleniumTestCase from e2e.utils import USER, SeleniumTestCase
from passbook import __version__
from passbook.core.models import Application from passbook.core.models import Application
from passbook.flows.models import Flow from passbook.flows.models import Flow
from passbook.outposts.models import Outpost, OutpostDeploymentType, OutpostType from passbook.outposts.models import Outpost, OutpostDeploymentType, OutpostType
@ -124,6 +124,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase):
return container return container
def test_proxy_connectivity(self): def test_proxy_connectivity(self):
"""Test proxy connectivity over websocket"""
SeleniumTestCase().apply_default_data() SeleniumTestCase().apply_default_data()
proxy: ProxyProvider = ProxyProvider.objects.create( proxy: ProxyProvider = ProxyProvider.objects.create(
name="proxy_provider", name="proxy_provider",

View file

@ -7,7 +7,6 @@ PROMPT_CONSNET = "consent"
SCOPE_OPENID = "openid" SCOPE_OPENID = "openid"
SCOPE_OPENID_PROFILE = "profile" SCOPE_OPENID_PROFILE = "profile"
SCOPE_OPENID_EMAIL = "email" SCOPE_OPENID_EMAIL = "email"
SCOPE_OPENID_INTROSPECTION = "token_introspection"
# Read/write full user (including email) # Read/write full user (including email)
SCOPE_GITHUB_USER = "user" SCOPE_GITHUB_USER = "user"

View file

@ -202,11 +202,6 @@ class OAuth2Provider(Provider):
), ),
) )
@property
def scope_names(self) -> List[str]:
"""Return list of assigned scopes seperated with a space"""
return [pm.scope_name for pm in self.property_mappings.all()]
def create_refresh_token( def create_refresh_token(
self, user: User, scope: List[str], id_token: Optional["IDToken"] = None self, user: User, scope: List[str], id_token: Optional["IDToken"] = None
) -> "RefreshToken": ) -> "RefreshToken":

View file

@ -2,7 +2,7 @@
import re import re
from base64 import b64decode from base64 import b64decode
from binascii import Error from binascii import Error
from typing import List, Tuple from typing import List, Optional, Tuple
from django.http import HttpRequest, HttpResponse, JsonResponse from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils.cache import patch_vary_headers from django.utils.cache import patch_vary_headers
@ -50,7 +50,7 @@ def cors_allow_any(request, response):
return response return response
def extract_access_token(request: HttpRequest) -> str: def extract_access_token(request: HttpRequest) -> Optional[str]:
""" """
Get the access token using Authorization Request Header Field method. Get the access token using Authorization Request Header Field method.
Or try getting via GET. Or try getting via GET.
@ -66,7 +66,7 @@ def extract_access_token(request: HttpRequest) -> str:
return request.POST.get("access_token") return request.POST.get("access_token")
if "access_token" in request.GET: if "access_token" in request.GET:
return request.GET.get("access_token") return request.GET.get("access_token")
return "" return None
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]: def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
@ -103,9 +103,12 @@ def protected_resource_view(scopes: List[str]):
def wrapper(view): def wrapper(view):
def view_wrapper(request, *args, **kwargs): def view_wrapper(request, *args, **kwargs):
access_token = extract_access_token(request)
try: try:
access_token = extract_access_token(request)
if not access_token:
LOGGER.debug("No token passed")
raise BearerTokenError("invalid_token")
try: try:
kwargs["token"] = RefreshToken.objects.get( kwargs["token"] = RefreshToken.objects.get(
access_token=access_token access_token=access_token

View file

@ -1,15 +1,17 @@
"""passbook OAuth2 Token Introspection Views""" """passbook OAuth2 Token Introspection Views"""
from dataclasses import InitVar, dataclass from dataclasses import dataclass, field
from typing import Optional
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.views import View from django.views import View
from structlog import get_logger from structlog import get_logger
from passbook.providers.oauth2.constants import SCOPE_OPENID_INTROSPECTION
from passbook.providers.oauth2.errors import TokenIntrospectionError from passbook.providers.oauth2.errors import TokenIntrospectionError
from passbook.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken from passbook.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken
from passbook.providers.oauth2.utils import TokenResponse, extract_client_auth from passbook.providers.oauth2.utils import (
TokenResponse,
extract_access_token,
extract_client_auth,
)
LOGGER = get_logger() LOGGER = get_logger()
@ -18,39 +20,17 @@ LOGGER = get_logger()
class TokenIntrospectionParams: class TokenIntrospectionParams:
"""Parameters for Token Introspection""" """Parameters for Token Introspection"""
client_id: str token: RefreshToken
client_secret: str
raw_token: InitVar[str] provider: OAuth2Provider = field(init=False)
id_token: IDToken = field(init=False)
token: Optional[RefreshToken] = None def __post_init__(self):
provider: Optional[OAuth2Provider] = None
id_token: Optional[IDToken] = None
def __post_init__(self, raw_token: str):
try:
self.token = RefreshToken.objects.get(access_token=raw_token)
except RefreshToken.DoesNotExist:
LOGGER.debug("Token does not exist", token=raw_token)
raise TokenIntrospectionError()
if self.token.is_expired: if self.token.is_expired:
LOGGER.debug("Token is not valid", token=raw_token) LOGGER.debug("Token is not valid")
raise TokenIntrospectionError()
try:
self.provider = OAuth2Provider.objects.get(
client_id=self.client_id, client_secret=self.client_secret,
)
except OAuth2Provider.DoesNotExist:
LOGGER.debug("provider for ID not found", client_id=self.client_id)
raise TokenIntrospectionError()
if SCOPE_OPENID_INTROSPECTION not in self.provider.scope_names:
LOGGER.debug(
"OAuth2Provider does not have introspection scope",
client_id=self.client_id,
)
raise TokenIntrospectionError() raise TokenIntrospectionError()
self.provider = self.token.provider
self.id_token = self.token.id_token self.id_token = self.token.id_token
if not self.token.id_token: if not self.token.id_token:
@ -59,31 +39,61 @@ class TokenIntrospectionParams:
) )
raise TokenIntrospectionError() raise TokenIntrospectionError()
audience = self.token.id_token.aud def authenticate_basic(self, request: HttpRequest) -> bool:
if not audience: """Attempt to authenticate via Basic auth of client_id:client_secret"""
LOGGER.debug( client_id, client_secret = extract_client_auth(request)
"No audience found for token", token=self.token, if client_id == client_secret == "":
) return False
if (
client_id != self.provider.client_id
or client_secret != self.provider.client_secret
):
LOGGER.debug("(basic) Provider for basic auth does not exist")
raise TokenIntrospectionError() raise TokenIntrospectionError()
return True
if audience not in self.provider.scope_names: def authenticate_bearer(self, request: HttpRequest) -> bool:
LOGGER.debug( """Attempt to authenticate via token sent as bearer header"""
"provider does not audience scope", body_token = extract_access_token(request)
client_id=self.client_id, if not body_token:
audience=audience, return False
) tokens = RefreshToken.objects.filter(access_token=body_token).select_related(
"provider"
)
if not tokens.exists():
LOGGER.debug("(bearer) Token does not exist")
raise TokenIntrospectionError() raise TokenIntrospectionError()
if tokens.first().provider != self.provider:
LOGGER.debug("(bearer) Token providers don't match")
raise TokenIntrospectionError()
return True
@staticmethod @staticmethod
def from_request(request: HttpRequest) -> "TokenIntrospectionParams": def from_request(request: HttpRequest) -> "TokenIntrospectionParams":
"""Extract required Parameters from HTTP Request""" """Extract required Parameters from HTTP Request"""
# Introspection only supports POST requests raw_token = request.POST.get("token")
client_id, client_secret = extract_client_auth(request) token_type_hint = request.POST.get("token_type_hint", "access_token")
return TokenIntrospectionParams( token_filter = {token_type_hint: raw_token}
raw_token=request.POST.get("token"),
client_id=client_id, if token_type_hint not in ["access_token", "refresh_token"]:
client_secret=client_secret, LOGGER.debug("token_type_hint has invalid value", value=token_type_hint)
) raise TokenIntrospectionError()
try:
token: RefreshToken = RefreshToken.objects.select_related("provider").get(
**token_filter
)
except RefreshToken.DoesNotExist:
LOGGER.debug("Token does not exist", token=raw_token)
raise TokenIntrospectionError()
params = TokenIntrospectionParams(token=token)
if not any(
[params.authenticate_basic(request), params.authenticate_bearer(request)]
):
LOGGER.debug("Not authenticated")
raise TokenIntrospectionError()
return params
class TokenIntrospectionView(View): class TokenIntrospectionView(View):
@ -101,12 +111,12 @@ class TokenIntrospectionView(View):
self.params = TokenIntrospectionParams.from_request(request) self.params = TokenIntrospectionParams.from_request(request)
response_dic = {} response_dic = {}
if self.id_token: if self.params.id_token:
token_dict = self.id_token.to_dict() token_dict = self.params.id_token.to_dict()
for k in ("aud", "sub", "exp", "iat", "iss"): for k in ("aud", "sub", "exp", "iat", "iss"):
response_dic[k] = token_dict[k] response_dic[k] = token_dict[k]
response_dic["active"] = True response_dic["active"] = True
response_dic["client_id"] = self.token.provider.client_id response_dic["client_id"] = self.params.token.provider.client_id
return TokenResponse(response_dic) return TokenResponse(response_dic)
except TokenIntrospectionError: except TokenIntrospectionError: