providers/oauth2: rewrite introspection endpoint to allow basic or bearer auth
This commit is contained in:
parent
553f184aad
commit
8f4e954160
|
@ -4,14 +4,14 @@ from time import sleep
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
|
from channels.testing import ChannelsLiveServerTestCase
|
||||||
from docker.client import DockerClient, from_env
|
from docker.client import DockerClient, from_env
|
||||||
from docker.models.containers import Container
|
from docker.models.containers import Container
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.common.keys import Keys
|
from selenium.webdriver.common.keys import Keys
|
||||||
from channels.testing import ChannelsLiveServerTestCase
|
|
||||||
|
|
||||||
from passbook import __version__
|
|
||||||
from e2e.utils import USER, SeleniumTestCase
|
from e2e.utils import USER, SeleniumTestCase
|
||||||
|
from passbook import __version__
|
||||||
from passbook.core.models import Application
|
from passbook.core.models import Application
|
||||||
from passbook.flows.models import Flow
|
from passbook.flows.models import Flow
|
||||||
from passbook.outposts.models import Outpost, OutpostDeploymentType, OutpostType
|
from passbook.outposts.models import Outpost, OutpostDeploymentType, OutpostType
|
||||||
|
@ -124,6 +124,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase):
|
||||||
return container
|
return container
|
||||||
|
|
||||||
def test_proxy_connectivity(self):
|
def test_proxy_connectivity(self):
|
||||||
|
"""Test proxy connectivity over websocket"""
|
||||||
SeleniumTestCase().apply_default_data()
|
SeleniumTestCase().apply_default_data()
|
||||||
proxy: ProxyProvider = ProxyProvider.objects.create(
|
proxy: ProxyProvider = ProxyProvider.objects.create(
|
||||||
name="proxy_provider",
|
name="proxy_provider",
|
||||||
|
|
|
@ -7,7 +7,6 @@ PROMPT_CONSNET = "consent"
|
||||||
SCOPE_OPENID = "openid"
|
SCOPE_OPENID = "openid"
|
||||||
SCOPE_OPENID_PROFILE = "profile"
|
SCOPE_OPENID_PROFILE = "profile"
|
||||||
SCOPE_OPENID_EMAIL = "email"
|
SCOPE_OPENID_EMAIL = "email"
|
||||||
SCOPE_OPENID_INTROSPECTION = "token_introspection"
|
|
||||||
|
|
||||||
# Read/write full user (including email)
|
# Read/write full user (including email)
|
||||||
SCOPE_GITHUB_USER = "user"
|
SCOPE_GITHUB_USER = "user"
|
||||||
|
|
|
@ -202,11 +202,6 @@ class OAuth2Provider(Provider):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def scope_names(self) -> List[str]:
|
|
||||||
"""Return list of assigned scopes seperated with a space"""
|
|
||||||
return [pm.scope_name for pm in self.property_mappings.all()]
|
|
||||||
|
|
||||||
def create_refresh_token(
|
def create_refresh_token(
|
||||||
self, user: User, scope: List[str], id_token: Optional["IDToken"] = None
|
self, user: User, scope: List[str], id_token: Optional["IDToken"] = None
|
||||||
) -> "RefreshToken":
|
) -> "RefreshToken":
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import re
|
import re
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from binascii import Error
|
from binascii import Error
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||||
from django.utils.cache import patch_vary_headers
|
from django.utils.cache import patch_vary_headers
|
||||||
|
@ -50,7 +50,7 @@ def cors_allow_any(request, response):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def extract_access_token(request: HttpRequest) -> str:
|
def extract_access_token(request: HttpRequest) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Get the access token using Authorization Request Header Field method.
|
Get the access token using Authorization Request Header Field method.
|
||||||
Or try getting via GET.
|
Or try getting via GET.
|
||||||
|
@ -66,7 +66,7 @@ def extract_access_token(request: HttpRequest) -> str:
|
||||||
return request.POST.get("access_token")
|
return request.POST.get("access_token")
|
||||||
if "access_token" in request.GET:
|
if "access_token" in request.GET:
|
||||||
return request.GET.get("access_token")
|
return request.GET.get("access_token")
|
||||||
return ""
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
|
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
|
||||||
|
@ -103,9 +103,12 @@ def protected_resource_view(scopes: List[str]):
|
||||||
|
|
||||||
def wrapper(view):
|
def wrapper(view):
|
||||||
def view_wrapper(request, *args, **kwargs):
|
def view_wrapper(request, *args, **kwargs):
|
||||||
access_token = extract_access_token(request)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
access_token = extract_access_token(request)
|
||||||
|
if not access_token:
|
||||||
|
LOGGER.debug("No token passed")
|
||||||
|
raise BearerTokenError("invalid_token")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kwargs["token"] = RefreshToken.objects.get(
|
kwargs["token"] = RefreshToken.objects.get(
|
||||||
access_token=access_token
|
access_token=access_token
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
"""passbook OAuth2 Token Introspection Views"""
|
"""passbook OAuth2 Token Introspection Views"""
|
||||||
from dataclasses import InitVar, dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.views import View
|
from django.views import View
|
||||||
from structlog import get_logger
|
from structlog import get_logger
|
||||||
|
|
||||||
from passbook.providers.oauth2.constants import SCOPE_OPENID_INTROSPECTION
|
|
||||||
from passbook.providers.oauth2.errors import TokenIntrospectionError
|
from passbook.providers.oauth2.errors import TokenIntrospectionError
|
||||||
from passbook.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken
|
from passbook.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken
|
||||||
from passbook.providers.oauth2.utils import TokenResponse, extract_client_auth
|
from passbook.providers.oauth2.utils import (
|
||||||
|
TokenResponse,
|
||||||
|
extract_access_token,
|
||||||
|
extract_client_auth,
|
||||||
|
)
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
@ -18,39 +20,17 @@ LOGGER = get_logger()
|
||||||
class TokenIntrospectionParams:
|
class TokenIntrospectionParams:
|
||||||
"""Parameters for Token Introspection"""
|
"""Parameters for Token Introspection"""
|
||||||
|
|
||||||
client_id: str
|
token: RefreshToken
|
||||||
client_secret: str
|
|
||||||
|
|
||||||
raw_token: InitVar[str]
|
provider: OAuth2Provider = field(init=False)
|
||||||
|
id_token: IDToken = field(init=False)
|
||||||
|
|
||||||
token: Optional[RefreshToken] = None
|
def __post_init__(self):
|
||||||
|
|
||||||
provider: Optional[OAuth2Provider] = None
|
|
||||||
id_token: Optional[IDToken] = None
|
|
||||||
|
|
||||||
def __post_init__(self, raw_token: str):
|
|
||||||
try:
|
|
||||||
self.token = RefreshToken.objects.get(access_token=raw_token)
|
|
||||||
except RefreshToken.DoesNotExist:
|
|
||||||
LOGGER.debug("Token does not exist", token=raw_token)
|
|
||||||
raise TokenIntrospectionError()
|
|
||||||
if self.token.is_expired:
|
if self.token.is_expired:
|
||||||
LOGGER.debug("Token is not valid", token=raw_token)
|
LOGGER.debug("Token is not valid")
|
||||||
raise TokenIntrospectionError()
|
|
||||||
try:
|
|
||||||
self.provider = OAuth2Provider.objects.get(
|
|
||||||
client_id=self.client_id, client_secret=self.client_secret,
|
|
||||||
)
|
|
||||||
except OAuth2Provider.DoesNotExist:
|
|
||||||
LOGGER.debug("provider for ID not found", client_id=self.client_id)
|
|
||||||
raise TokenIntrospectionError()
|
|
||||||
if SCOPE_OPENID_INTROSPECTION not in self.provider.scope_names:
|
|
||||||
LOGGER.debug(
|
|
||||||
"OAuth2Provider does not have introspection scope",
|
|
||||||
client_id=self.client_id,
|
|
||||||
)
|
|
||||||
raise TokenIntrospectionError()
|
raise TokenIntrospectionError()
|
||||||
|
|
||||||
|
self.provider = self.token.provider
|
||||||
self.id_token = self.token.id_token
|
self.id_token = self.token.id_token
|
||||||
|
|
||||||
if not self.token.id_token:
|
if not self.token.id_token:
|
||||||
|
@ -59,31 +39,61 @@ class TokenIntrospectionParams:
|
||||||
)
|
)
|
||||||
raise TokenIntrospectionError()
|
raise TokenIntrospectionError()
|
||||||
|
|
||||||
audience = self.token.id_token.aud
|
def authenticate_basic(self, request: HttpRequest) -> bool:
|
||||||
if not audience:
|
"""Attempt to authenticate via Basic auth of client_id:client_secret"""
|
||||||
LOGGER.debug(
|
client_id, client_secret = extract_client_auth(request)
|
||||||
"No audience found for token", token=self.token,
|
if client_id == client_secret == "":
|
||||||
)
|
return False
|
||||||
|
if (
|
||||||
|
client_id != self.provider.client_id
|
||||||
|
or client_secret != self.provider.client_secret
|
||||||
|
):
|
||||||
|
LOGGER.debug("(basic) Provider for basic auth does not exist")
|
||||||
raise TokenIntrospectionError()
|
raise TokenIntrospectionError()
|
||||||
|
return True
|
||||||
|
|
||||||
if audience not in self.provider.scope_names:
|
def authenticate_bearer(self, request: HttpRequest) -> bool:
|
||||||
LOGGER.debug(
|
"""Attempt to authenticate via token sent as bearer header"""
|
||||||
"provider does not audience scope",
|
body_token = extract_access_token(request)
|
||||||
client_id=self.client_id,
|
if not body_token:
|
||||||
audience=audience,
|
return False
|
||||||
)
|
tokens = RefreshToken.objects.filter(access_token=body_token).select_related(
|
||||||
|
"provider"
|
||||||
|
)
|
||||||
|
if not tokens.exists():
|
||||||
|
LOGGER.debug("(bearer) Token does not exist")
|
||||||
raise TokenIntrospectionError()
|
raise TokenIntrospectionError()
|
||||||
|
if tokens.first().provider != self.provider:
|
||||||
|
LOGGER.debug("(bearer) Token providers don't match")
|
||||||
|
raise TokenIntrospectionError()
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_request(request: HttpRequest) -> "TokenIntrospectionParams":
|
def from_request(request: HttpRequest) -> "TokenIntrospectionParams":
|
||||||
"""Extract required Parameters from HTTP Request"""
|
"""Extract required Parameters from HTTP Request"""
|
||||||
# Introspection only supports POST requests
|
raw_token = request.POST.get("token")
|
||||||
client_id, client_secret = extract_client_auth(request)
|
token_type_hint = request.POST.get("token_type_hint", "access_token")
|
||||||
return TokenIntrospectionParams(
|
token_filter = {token_type_hint: raw_token}
|
||||||
raw_token=request.POST.get("token"),
|
|
||||||
client_id=client_id,
|
if token_type_hint not in ["access_token", "refresh_token"]:
|
||||||
client_secret=client_secret,
|
LOGGER.debug("token_type_hint has invalid value", value=token_type_hint)
|
||||||
)
|
raise TokenIntrospectionError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
token: RefreshToken = RefreshToken.objects.select_related("provider").get(
|
||||||
|
**token_filter
|
||||||
|
)
|
||||||
|
except RefreshToken.DoesNotExist:
|
||||||
|
LOGGER.debug("Token does not exist", token=raw_token)
|
||||||
|
raise TokenIntrospectionError()
|
||||||
|
|
||||||
|
params = TokenIntrospectionParams(token=token)
|
||||||
|
if not any(
|
||||||
|
[params.authenticate_basic(request), params.authenticate_bearer(request)]
|
||||||
|
):
|
||||||
|
LOGGER.debug("Not authenticated")
|
||||||
|
raise TokenIntrospectionError()
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
class TokenIntrospectionView(View):
|
class TokenIntrospectionView(View):
|
||||||
|
@ -101,12 +111,12 @@ class TokenIntrospectionView(View):
|
||||||
self.params = TokenIntrospectionParams.from_request(request)
|
self.params = TokenIntrospectionParams.from_request(request)
|
||||||
|
|
||||||
response_dic = {}
|
response_dic = {}
|
||||||
if self.id_token:
|
if self.params.id_token:
|
||||||
token_dict = self.id_token.to_dict()
|
token_dict = self.params.id_token.to_dict()
|
||||||
for k in ("aud", "sub", "exp", "iat", "iss"):
|
for k in ("aud", "sub", "exp", "iat", "iss"):
|
||||||
response_dic[k] = token_dict[k]
|
response_dic[k] = token_dict[k]
|
||||||
response_dic["active"] = True
|
response_dic["active"] = True
|
||||||
response_dic["client_id"] = self.token.provider.client_id
|
response_dic["client_id"] = self.params.token.provider.client_id
|
||||||
|
|
||||||
return TokenResponse(response_dic)
|
return TokenResponse(response_dic)
|
||||||
except TokenIntrospectionError:
|
except TokenIntrospectionError:
|
||||||
|
|
Reference in a new issue