# Copyright (c) 2026 Matthew Baker. All rights reserved. Licenced under the Apache Licence 2.0. See LICENSE file
from crossauth_backend.oauth.wellknown import OpenIdConfiguration
from crossauth_backend.utils import set_parameter, ParamType
from crossauth_backend.common.error import CrossauthError, ErrorCode
from crossauth_backend.common.interfaces import KeyPrefix
from crossauth_backend.common.logger import CrossauthLogger, j
from crossauth_backend.crypto import Crypto
from crossauth_backend.storage import KeyStorage
from crossauth_backend.utils import MapGetter
from typing import TypedDict, Optional, Dict, Any, Literal, List
import json
from datetime import datetime
import aiohttp
from jwt import (
JWT,
jwk_from_dict,
jwk_from_pem,
AbstractJWKBase
)
# A key given either as a string-keyed mapping or as a plain string.
# NOTE(review): presumably a JWK dictionary or PEM text - confirm against the
# callers; nothing in the visible part of this file uses the alias.
type EncryptionKey = Dict[str, Any] | str
[docs]
class Keys(TypedDict):
    """
    A set of JSON Web Keys, in the shape accepted by
    :meth:`OAuthTokenConsumer.load_jwks`.
    """

    # Every key is registered by load_jwks under the value of its get_kid().
    keys: List[AbstractJWKBase]
[docs]
class OAuthTokenConsumerOptions(TypedDict, total=False):
    """
    Options that can be passed to :class:`OAuthTokenConsumer`.
    """

    jwt_key_type : str
    """
    Type of the key used for signing the JWT.
    Must be set whenever `jwt_public_key` is set.
    (Presumably an algorithm name such as ``RS256`` - TODO confirm.)
    """

    jwt_secret_key : str
    """ Secret key if using a symmetric cipher for signing the JWT.
    Either this or `jwt_secret_key_file` is required when using this kind of
    cipher. Cannot be combined with a public key or public key file. """

    jwt_public_key : str
    """ The public key if using a public key cipher for signing the JWT.
    Either this or `jwt_public_key_file` is required when using this kind of
    cipher. `jwt_key_type` must also be set. """

    clock_tolerance : int
    """ Number of seconds tolerance when checking expiration. Default 10"""

    auth_server_base_url : str
    """ The value to expect in the iss claim. If the iss does not match
    this, the token is rejected. It is also the base URL that
    `/.well-known/openid-configuration` is appended to when the OpenID
    Connect configuration has to be fetched.
    No default (required)"""

    oidc_config : Optional[OpenIdConfiguration|Dict[str,Any]]
    """
    For initializing the token consumer with a static OpenID Connect
    configuration.
    """

    persist_access_token : bool
    """ Whether access tokens are persisted in key storage. Default False.
    If you set this to True, you must also set `key_storage`; access tokens
    are then also looked up in that storage when they are validated.
    """

    key_storage : Optional[KeyStorage]
    """ If persisting tokens, you need to provide a storage to persist them to"""

    jwt_secret_key_file : str
    """ Filename with secret key if using a symmetric cipher for signing the
    JWT. Either this or `jwt_secret_key` is required when using this kind
    of cipher; setting both is an error."""

    jwt_public_key_file : str
    """ Filename for the public key if using a public key cipher for signing the
    JWT. Either this or `jwt_public_key` is required when using this kind of
    cipher; setting both is an error."""

    audience : str
    """
    The aud claim needs to match this value.
    NOTE(review): the :class:`OAuthTokenConsumer` constructor takes the
    audience as its own argument and does not read it from these options -
    confirm whether this entry is still used.
    """
[docs]
class OAuthTokenConsumer:
"""
This abstract class is for validating OAuth JWTs.
"""
@property
def auth_server_base_url(self) -> str:
    """
    Base URL of the authorization server. Tokens whose ``iss`` claim
    differs from it are rejected.
    """
    return self._auth_server_base_url
@property
def audience(self) -> str:
    """The value a token's ``aud`` claim is compared against."""
    return self._audience

@audience.setter
def audience(self, val: str) -> None:
    self._audience = val
@property
def _jwt_key_type(self):
return self.__jwt_key_type
@_jwt_key_type.setter
def _jwt_key_type(self, val : str|None):
self.__jwt_key_type = val
@property
def _jwt_public_key_file(self):
return self.__jwt_public_key_file
@_jwt_public_key_file.setter
def _jwt_public_key_file(self, val : str|None):
self.__jwt_public_key_file = val
@property
def _jwt_public_key(self):
return self.__jwt_public_key
@_jwt_public_key.setter
def _jwt_public_key(self, val : bytes|None):
self.__jwt_public_key = val
@property
def _jwt_secret_key(self):
return self.__jwt_secret_key
@_jwt_secret_key.setter
def _jwt_secret_key(self, val : bytes|None):
self.__jwt_secret_key = val
@property
def _clock_tolerance(self) -> int:
    """Configured clock tolerance; the constructor initialises it to 10."""
    return self.__clock_tolerance

@_clock_tolerance.setter
def _clock_tolerance(self, val: int) -> None:
    self.__clock_tolerance = val
@property
def _persist_access_token(self) -> bool:
    """
    Whether access tokens are additionally looked up in key storage
    when they are validated.
    """
    return self.__persist_access_token

@_persist_access_token.setter
def _persist_access_token(self, val: bool) -> None:
    self.__persist_access_token = val
@property
def oidc_config(self):
    """
    The OpenID Connect configuration: taken from the constructor options,
    given to :meth:`load_config`, or fetched by it from the authorization
    server. Unset until one of those has happened.
    """
    return self._oidc_config

@oidc_config.setter
def oidc_config(self, val: Optional["OpenIdConfiguration"]) -> None:
    self._oidc_config = val
@property
def keys(self):
    """
    The loaded signature-validation keys, indexed by key id. A key
    without an id is stored under ``_default``.
    """
    return self._keys

@keys.setter
def keys(self, val: Dict[str, "AbstractJWKBase"]) -> None:
    self._keys = val
def __init__(self, audience: str, options: Optional[OAuthTokenConsumerOptions] = None):
    """
    Configure the token consumer.

    Only configuration is processed here, including reading a key file
    if one was named. The validation keys themselves are built later by
    :meth:`load_keys`.

    :param str audience: the value the ``aud`` claim of a token is
        compared against.
    :param OAuthTokenConsumerOptions options: see
        :class:`OAuthTokenConsumerOptions`. ``auth_server_base_url`` is
        passed to ``set_parameter`` with ``required=True``.

    :raises ValueError: if a public key is given without a key type.
    :raises CrossauthError: with :class:`ErrorCode` of ``Configuration``
        if both a symmetric and a public key are configured, or if a key
        is given both directly and as a file.
    """
    # A shared `{}` default would be one dict object reused by every call.
    if options is None:
        options = {}

    self._audience = audience
    self.__jwt_key_type : str|None = ""
    self._auth_server_base_url : str = ""
    self.__jwt_secret_key : bytes|None = None
    self.__jwt_public_key : bytes|None = None
    self.__jwt_secret_key_file : str|None = None
    self.__jwt_public_key_file : str|None = None
    self.__clock_tolerance = 10
    self.__persist_access_token = False
    self._oidc_config = options.get('oidc_config')
    self._keys : Dict[str, AbstractJWKBase] = {}
    self.__key_storage : KeyStorage|None = options.get("key_storage")

    # Each option is applied through set_parameter; the upper-case string
    # is presumably the name of an environment-variable fallback - confirm
    # against set_parameter.
    # NOTE(review): every option below except jwt_secret_key_file has a
    # matching `_<name>` property on this class. Confirm that set_parameter
    # really writes the attribute read below as self.__jwt_secret_key_file.
    set_parameter("auth_server_base_url", ParamType.String, self, options, "AUTH_SERVER_BASE_URL", required=True, protected=True)
    set_parameter("jwt_key_type", ParamType.String, self, options, "JWT_KEY_TYPE")
    set_parameter("jwt_public_key_file", ParamType.String, self, options, "JWT_PUBLIC_KEY_FILE")
    set_parameter("jwt_secret_key_file", ParamType.String, self, options, "JWT_SECRET_KEY_FILE")
    set_parameter("jwt_secret_key", ParamType.String, self, options, "JWT_SECRET_KEY")
    set_parameter("jwt_public_key", ParamType.String, self, options, "JWT_PUBLIC_KEY")
    set_parameter("clock_tolerance", ParamType.Integer, self, options, "OAUTH_CLOCK_TOLERANCE")
    set_parameter("persist_access_token", ParamType.Boolean, self, options, "OAUTH_PERSIST_ACCESS_TOKEN")

    if self._jwt_public_key and not self._jwt_key_type:
        raise ValueError("If specifying jwtPublic key, must also specify jwtKeyType")

    if self._jwt_secret_key or self.__jwt_secret_key_file:
        # Symmetric key: must not be mixed with a public key, and must
        # come from exactly one source.
        if self._jwt_public_key or self.__jwt_public_key_file:
            raise CrossauthError(ErrorCode.Configuration,
                "Cannot specify symmetric and public/private JWT keys")
        if self._jwt_secret_key and self.__jwt_secret_key_file:
            raise CrossauthError(ErrorCode.Configuration,
                "Cannot specify symmetric key and file")
        # Truthiness, like the public-key branch, so that an empty file
        # name is never passed to open().
        if self.__jwt_secret_key_file:
            with open(self.__jwt_secret_key_file, 'rb') as f:
                self._jwt_secret_key = f.read()
    elif self._jwt_public_key is not None or self._jwt_public_key_file is not None:
        if self._jwt_public_key_file is not None and self._jwt_public_key is not None:
            raise CrossauthError(ErrorCode.Configuration,
                "Cannot specify both public key and public key file")
        if self._jwt_public_key_file:
            with open(self._jwt_public_key_file, 'rb') as f:
                self._jwt_public_key = f.read()
[docs]
async def load_keys(self):
    """
    Populate :attr:`keys` with the keys used to validate token signatures.

    If a secret key or a public key was configured it is stored under the
    key id ``_default``. Otherwise the OpenID Connect configuration is
    loaded if it is not yet set, and the keys are fetched with
    :meth:`load_jwks`.

    :raises ValueError: if the keys could not be loaded for any reason.
        The original exception is attached as ``__cause__``.
    """
    try:
        # Same precedence as before: the secret key wins over the public one.
        pem = self._jwt_secret_key or self._jwt_public_key
        if pem:
            # A key passed as an option is a str, but jwk_from_pem takes bytes.
            if isinstance(pem, str):
                pem = pem.encode("utf-8")
            # NOTE(review): jwk_from_pem is applied to the symmetric secret
            # too - confirm that secrets are expected to be PEM encoded.
            jwk : AbstractJWKBase = jwk_from_pem(pem)
            self.keys["_default"] = jwk
        else:
            if not self.oidc_config:
                await self.load_config()
            if not self.oidc_config:
                raise ValueError("Load OIDC config before Jwks")
            await self.load_jwks()
    except Exception as e:
        ce = CrossauthError.as_crossauth_error(e)
        CrossauthLogger.logger().debug(j({"err": e}))
        CrossauthLogger.logger().error(j({"msg": "Couldn't load keys", "cerr": ce}))
        # Chain the cause so the underlying failure is not lost.
        raise ValueError("Couldn't load keys") from e
[docs]
async def load_config(self, oidc_config: Optional[Dict[str, Any]] = None):
    """
    Sets the OpenID Connect configuration, or fetches it from the
    authorization server (using the well-known endpoint appended
    to the authorization server base URL).

    :param Dict[str, Any]|None oidc_config: the configuration to use, or
        None (or empty) to fetch it from the authorization server

    :raises ValueError: if no configuration was given and no authorization
        server base URL is set, or if the fetch failed, returned an HTTP
        error status or did not return a JSON object. In the latter cases
        the original exception is attached as ``__cause__``.
    """
    if oidc_config:
        self._oidc_config = oidc_config
        return

    if not self._auth_server_base_url:
        raise ValueError("Couldn't get OIDC configuration. Either set authServerBaseUrl or set config manually")

    url = f"{self._auth_server_base_url}/.well-known/openid-configuration"
    try:
        async with aiohttp.ClientSession() as session:
            # The context manager releases the response back to the session.
            async with session.get(url) as resp:
                # Discard any previous configuration once a response
                # arrives, so a failure below leaves it empty.
                self._oidc_config = {}
                # An error page must not be stored as configuration.
                resp.raise_for_status()
                body = await resp.json()
                self._oidc_config = {**body}
    except Exception as e:
        CrossauthLogger.logger().debug(j({"err": e}))
        raise ValueError("Unrecognized response from OIDC configuration endpoint") from e
[docs]
async def load_jwks(self, jwks: Optional["Keys"] = None):
    """
    Loads the JWT signature validation keys, or fetches them from the
    authorization server (using the ``jwks_uri`` in the OIDC
    configuration).

    Given keys are stored under the value of their ``get_kid()``. Fetched
    keys are stored under their ``kid``, or under ``_default`` when they
    have none.

    :param Keys|None jwks: the keys to load, or None (or empty) to fetch
        them from the authorization server.

    :raises ValueError: if the keys have to be fetched and the OIDC
        configuration isn't set, or if the fetch failed, returned an HTTP
        error status, or the keys could not be parsed. In the latter cases
        the original exception is attached as ``__cause__``.
    """
    if jwks:
        self.keys = {jwk.get_kid(): jwk for jwk in jwks['keys']}
        return

    if not self.oidc_config:
        raise ValueError("Load OIDC config before Jwks")

    try:
        async with aiohttp.ClientSession() as session:
            # The context manager releases the response back to the session.
            async with session.get(self.oidc_config['jwks_uri']) as resp:
                # Drop the old keys once a response arrives, so a failure
                # below leaves no keys rather than stale or partial ones.
                self.keys = {}
                resp.raise_for_status()
                body = await resp.json()
        if 'keys' not in body or not isinstance(body['keys'], list):
            raise ValueError("Couldn't fetch keys")
        # Parse everything first, then publish the complete set at once.
        fetched : Dict[str, AbstractJWKBase] = {}
        for k in body['keys']:
            fetched[k.get('kid', "_default")] = jwk_from_dict(k)
        self.keys = fetched
    except Exception as e:
        CrossauthLogger.logger().debug(j({"cerr": e}))
        raise ValueError("Unrecognized response from OIDC jwks endpoint") from e
async def _token_authorized(self, token: str, token_type: str) -> Optional[Dict[str, Any]]:
    """
    Validate ``token`` and check its ``type``, ``iss`` and ``aud`` claims.

    Keys are loaded first if none are loaded yet. The token is rejected if
    it fails :meth:`_validate_token`, if its ``type`` claim differs from
    ``token_type``, if its ``iss`` claim differs from the authorization
    server base URL, or if it has an ``aud`` claim that does not contain
    (list) or equal (single value) the configured audience. A token
    without an ``aud`` claim is not rejected on that ground.

    :param str token: the encoded token
    :param str token_type: the value the ``type`` claim must have

    :return: the decoded payload, or None if the token was rejected

    :raises ValueError: from :meth:`load_keys` if keys could not be loaded
    """
    if not self.keys:
        await self.load_keys()

    decoded = await self._validate_token(token)
    if not decoded:
        return None

    if decoded.get('type') != token_type:
        # f-string rather than "+": a non-string `type` claim must only be
        # logged, not raise TypeError.
        CrossauthLogger.logger().error(j({"msg": f"{token_type} expected but got {decoded.get('type', '')}"}))
        return None

    if decoded.get('iss') != self._auth_server_base_url:
        # Such a token need not carry a usable `jti`, and rejecting it
        # must not raise KeyError.
        jti = decoded.get('jti')
        hashed_jti = self.hash(jti) if isinstance(jti, str) else None
        CrossauthLogger.logger().error(j({"msg": f"Invalid issuer: {decoded.get('iss')} in access token", "hashedAccessToken": hashed_jti}))
        return None

    if 'aud' in decoded:
        aud = decoded['aud']
        audience_ok = self._audience in aud if isinstance(aud, list) else aud == self._audience
        if not audience_ok:
            CrossauthLogger.logger().warn(j({"msg": f"Invalid audience: {decoded['aud']} in access token"}))
            return None

    return decoded
[docs]
async def token_authorized(self, token: str, tokenType: Literal["access", "refresh", "id"]) -> Optional[Dict[str, Any]]:
    """
    If the given token is valid, the payload is returned. Otherwise
    None is returned.

    The signature must be valid, and the ``type``, ``iss`` and ``aud``
    claims must pass the checks in :meth:`_token_authorized`. For access
    tokens, when access tokens are persisted and a key storage is set,
    the token must also be found in storage and must not be expired there.

    An invalid token does not raise. The one exception that can escape is
    the ``ValueError`` raised by :meth:`load_keys` when the validation keys
    cannot be loaded.

    :param str token: The token to validate
    :param Literal["access", "refresh", "id"] tokenType: the `type` claim
        in the payload must match this value

    :return: the token payload, or None if the token is not valid
    """
    payload = await self._token_authorized(token, tokenType)
    if not payload:
        return None

    if tokenType == "access" and self._persist_access_token and self.__key_storage:
        jti = payload.get('jti')
        if not isinstance(jti, str):
            # Without a jti there is no storage key to look up. Reject
            # here instead of raising KeyError.
            CrossauthLogger.logger().warn(j({"msg": "Access token has no jti claim so cannot be looked up in storage"}))
            return None
        hashed_jti = Crypto.hash(jti)
        try:
            token_in_storage = await self.__key_storage.get_key(KeyPrefix.access_token + hashed_jti)
            now = datetime.now()
            if "expires" in token_in_storage and MapGetter[datetime].get(token_in_storage, "expires", now).timestamp() < now.timestamp():
                CrossauthLogger.logger().error(j({"msg":"Access token expired in storage but not in JWT"}))
                return None
        except Exception as e:
            # Any lookup failure is treated as "token not valid".
            CrossauthLogger.logger().warn(j({
                "msg": "Couldn't get token from database - is it valid?",
                "hashedAccessToken": hashed_jti
            }))
            CrossauthLogger.logger().debug(j({"err": str(e)}))
            return None

    return payload
async def _validate_token(self, access_token: str) -> Optional[Dict[str, Any]]:
    """
    Decode ``access_token`` with the loaded key selected by its header.

    The key whose id equals the ``kid`` in the token header is used. If
    the header has no ``kid``, or no loaded key has that id, the
    ``_default`` key is used instead.

    :param str access_token: the encoded token

    :return: the decoded payload, or None if the header could not be
        read, no key was available, or decoding failed
    """
    if not self.keys:
        CrossauthLogger.logger().warn(j({"msg": "No keys loaded so cannot validate tokens"}))

    decoder = JWT()

    token_kid : str|None = None
    try:
        token_kid = OAuthTokenConsumer.decode_header(access_token).get('kid', "_default")
    except Exception as err:
        cerr = CrossauthError.as_crossauth_error(err)
        CrossauthLogger.logger().debug(j({"err": err}))
        CrossauthLogger.logger().warn(j({"msg": "Invalid access token format", "cerr": cerr}))
        return None

    # The first loaded key with a matching id wins; `_default` is the fallback.
    fallback = self.keys.get("_default")
    verification_key = next(
        (self.keys[name] for name in self.keys if token_kid == name),
        fallback,
    )
    if not verification_key:
        CrossauthLogger.logger().warn(j({"msg": "No matching keys found for access token"}))
        return None

    # The clock tolerance setting is not applied here; expiry handling is
    # presumably left to the decode call - TODO confirm.
    try:
        return decoder.decode(access_token, verification_key)
    except Exception as err:
        CrossauthLogger.logger().debug(j({"err":err}))
        return None
[docs]
def hash(self, plaintext: str) -> str:
    """
    Return ``Crypto.hash`` of ``plaintext``.

    :param str plaintext: the text to hash

    :return: the hash as returned by ``Crypto.hash``
    """
    hashed = Crypto.hash(plaintext)
    return hashed