123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548 |
- import datetime
- from typing import Any
- from typing import Callable
- from typing import Optional
- import jwt
- from flask import Flask
- from jwt import DecodeError
- from jwt import ExpiredSignatureError
- from jwt import InvalidAudienceError
- from jwt import InvalidIssuerError
- from jwt import InvalidTokenError
- from jwt import MissingRequiredClaimError
- from flask_jwt_extended.config import config
- from flask_jwt_extended.default_callbacks import default_additional_claims_callback
- from flask_jwt_extended.default_callbacks import default_blocklist_callback
- from flask_jwt_extended.default_callbacks import default_decode_key_callback
- from flask_jwt_extended.default_callbacks import default_encode_key_callback
- from flask_jwt_extended.default_callbacks import default_expired_token_callback
- from flask_jwt_extended.default_callbacks import default_invalid_token_callback
- from flask_jwt_extended.default_callbacks import default_jwt_headers_callback
- from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback
- from flask_jwt_extended.default_callbacks import default_revoked_token_callback
- from flask_jwt_extended.default_callbacks import default_token_verification_callback
- from flask_jwt_extended.default_callbacks import (
- default_token_verification_failed_callback,
- )
- from flask_jwt_extended.default_callbacks import default_unauthorized_callback
- from flask_jwt_extended.default_callbacks import default_user_identity_callback
- from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback
- from flask_jwt_extended.exceptions import CSRFError
- from flask_jwt_extended.exceptions import FreshTokenRequired
- from flask_jwt_extended.exceptions import InvalidHeaderError
- from flask_jwt_extended.exceptions import InvalidQueryParamError
- from flask_jwt_extended.exceptions import JWTDecodeError
- from flask_jwt_extended.exceptions import NoAuthorizationError
- from flask_jwt_extended.exceptions import RevokedTokenError
- from flask_jwt_extended.exceptions import UserClaimsVerificationError
- from flask_jwt_extended.exceptions import UserLookupError
- from flask_jwt_extended.exceptions import WrongTokenError
- from flask_jwt_extended.tokens import _decode_jwt
- from flask_jwt_extended.tokens import _encode_jwt
- from flask_jwt_extended.typing import ExpiresDelta
- class JWTManager(object):
- """
- An object used to hold JWT settings and callback functions for the
- Flask-JWT-Extended extension.
- Instances of :class:`JWTManager` are *not* bound to specific apps, so
- you can create one in the main body of your code and then bind it
- to your app in a factory function.
- """
- def __init__(self, app: Flask = None) -> None:
- """
- Create the JWTManager instance. You can either pass a flask application
- in directly here to register this extension with the flask app, or
- call init_app after creating this object (in a factory pattern).
- :param app:
- The Flask Application object
- """
- # Register the default error handler callback methods. These can be
- # overridden with the appropriate loader decorators
- self._decode_key_callback = default_decode_key_callback
- self._encode_key_callback = default_encode_key_callback
- self._expired_token_callback = default_expired_token_callback
- self._invalid_token_callback = default_invalid_token_callback
- self._jwt_additional_header_callback = default_jwt_headers_callback
- self._needs_fresh_token_callback = default_needs_fresh_token_callback
- self._revoked_token_callback = default_revoked_token_callback
- self._token_in_blocklist_callback = default_blocklist_callback
- self._token_verification_callback = default_token_verification_callback
- self._unauthorized_callback = default_unauthorized_callback
- self._user_claims_callback = default_additional_claims_callback
- self._user_identity_callback = default_user_identity_callback
- self._user_lookup_callback: Optional[Callable] = None
- self._user_lookup_error_callback = default_user_lookup_error_callback
- self._token_verification_failed_callback = (
- default_token_verification_failed_callback
- )
- # Register this extension with the flask app now (if it is provided)
- if app is not None:
- self.init_app(app)
- def init_app(self, app: Flask) -> None:
- """
- Register this extension with the flask app.
- :param app:
- The Flask Application object
- """
- # Save this so we can use it later in the extension
- if not hasattr(app, "extensions"): # pragma: no cover
- app.extensions = {}
- app.extensions["flask-jwt-extended"] = self
- # Set all the default configurations for this extension
- self._set_default_configuration_options(app)
- self._set_error_handler_callbacks(app)
- def _set_error_handler_callbacks(self, app: Flask) -> None:
- @app.errorhandler(CSRFError)
- def handle_csrf_error(e):
- return self._unauthorized_callback(str(e))
- @app.errorhandler(DecodeError)
- def handle_decode_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(ExpiredSignatureError)
- def handle_expired_error(e):
- return self._expired_token_callback(e.jwt_header, e.jwt_data)
- @app.errorhandler(FreshTokenRequired)
- def handle_fresh_token_required(e):
- return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data)
- @app.errorhandler(MissingRequiredClaimError)
- def handle_missing_required_claim_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(InvalidAudienceError)
- def handle_invalid_audience_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(InvalidIssuerError)
- def handle_invalid_issuer_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(InvalidHeaderError)
- def handle_invalid_header_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(InvalidTokenError)
- def handle_invalid_token_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(JWTDecodeError)
- def handle_jwt_decode_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(NoAuthorizationError)
- def handle_auth_error(e):
- return self._unauthorized_callback(str(e))
- @app.errorhandler(InvalidQueryParamError)
- def handle_invalid_query_param_error(e):
- return self._invalid_token_callback(str(e))
- @app.errorhandler(RevokedTokenError)
- def handle_revoked_token_error(e):
- return self._revoked_token_callback(e.jwt_header, e.jwt_data)
- @app.errorhandler(UserClaimsVerificationError)
- def handle_failed_token_verification(e):
- return self._token_verification_failed_callback(e.jwt_header, e.jwt_data)
- @app.errorhandler(UserLookupError)
- def handler_user_lookup_error(e):
- return self._user_lookup_error_callback(e.jwt_header, e.jwt_data)
- @app.errorhandler(WrongTokenError)
- def handle_wrong_token_error(e):
- return self._invalid_token_callback(str(e))
- @staticmethod
- def _set_default_configuration_options(app: Flask) -> None:
- app.config.setdefault(
- "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15)
- )
- app.config.setdefault("JWT_ACCESS_COOKIE_NAME", "access_token_cookie")
- app.config.setdefault("JWT_ACCESS_COOKIE_PATH", "/")
- app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_NAME", "csrf_access_token")
- app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_PATH", "/")
- app.config.setdefault("JWT_ACCESS_CSRF_FIELD_NAME", "csrf_token")
- app.config.setdefault("JWT_ACCESS_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
- app.config.setdefault("JWT_ALGORITHM", "HS256")
- app.config.setdefault("JWT_COOKIE_CSRF_PROTECT", True)
- app.config.setdefault("JWT_COOKIE_DOMAIN", None)
- app.config.setdefault("JWT_COOKIE_SAMESITE", None)
- app.config.setdefault("JWT_COOKIE_SECURE", False)
- app.config.setdefault("JWT_CSRF_CHECK_FORM", False)
- app.config.setdefault("JWT_CSRF_IN_COOKIES", True)
- app.config.setdefault("JWT_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
- app.config.setdefault("JWT_DECODE_ALGORITHMS", None)
- app.config.setdefault("JWT_DECODE_AUDIENCE", None)
- app.config.setdefault("JWT_DECODE_ISSUER", None)
- app.config.setdefault("JWT_DECODE_LEEWAY", 0)
- app.config.setdefault("JWT_ENCODE_AUDIENCE", None)
- app.config.setdefault("JWT_ENCODE_ISSUER", None)
- app.config.setdefault("JWT_ERROR_MESSAGE_KEY", "msg")
- app.config.setdefault("JWT_HEADER_NAME", "Authorization")
- app.config.setdefault("JWT_HEADER_TYPE", "Bearer")
- app.config.setdefault("JWT_IDENTITY_CLAIM", "sub")
- app.config.setdefault("JWT_JSON_KEY", "access_token")
- app.config.setdefault("JWT_PRIVATE_KEY", None)
- app.config.setdefault("JWT_PUBLIC_KEY", None)
- app.config.setdefault("JWT_QUERY_STRING_NAME", "jwt")
- app.config.setdefault("JWT_QUERY_STRING_VALUE_PREFIX", "")
- app.config.setdefault("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie")
- app.config.setdefault("JWT_REFRESH_COOKIE_PATH", "/")
- app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_NAME", "csrf_refresh_token")
- app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_PATH", "/")
- app.config.setdefault("JWT_REFRESH_CSRF_FIELD_NAME", "csrf_token")
- app.config.setdefault("JWT_REFRESH_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
- app.config.setdefault("JWT_REFRESH_JSON_KEY", "refresh_token")
- app.config.setdefault("JWT_REFRESH_TOKEN_EXPIRES", datetime.timedelta(days=30))
- app.config.setdefault("JWT_SECRET_KEY", None)
- app.config.setdefault("JWT_SESSION_COOKIE", True)
- app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",))
- app.config.setdefault("JWT_ENCODE_NBF", True)
- def additional_claims_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to add additional claims
- when creating a JWT. The claims returned by this function will be merged
- with any claims passed in via the ``additional_claims`` argument to
- :func:`~flask_jwt_extended.create_access_token` or
- :func:`~flask_jwt_extended.create_refresh_token`.
- The decorated function must take **one** argument.
- The argument is the identity that was used when creating a JWT.
- The decorated function must return a dictionary of claims to add to the JWT.
- """
- self._user_claims_callback = callback
- return callback
- def additional_headers_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to add additional headers
- when creating a JWT. The headers returned by this function will be merged
- with any headers passed in via the ``additional_headers`` argument to
- :func:`~flask_jwt_extended.create_access_token` or
- :func:`~flask_jwt_extended.create_refresh_token`.
- The decorated function must take **one** argument.
- The argument is the identity that was used when creating a JWT.
- The decorated function must return a dictionary of headers to add to the JWT.
- """
- self._jwt_additional_header_callback = callback
- return callback
- def decode_key_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function for dynamically setting the JWT
- decode key based on the **UNVERIFIED** contents of the token. Think
- carefully before using this functionality, in most cases you probably
- don't need it.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the
- unverified JWT.
- The second argument is a dictionary containing the payload data of the
- unverified JWT.
- The decorated function must return a *string* that is used to decode and
- verify the token.
- """
- self._decode_key_callback = callback
- return callback
- def encode_key_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function for dynamically setting the JWT
- encode key based on the tokens identity. Think carefully before using this
- functionality, in most cases you probably don't need it.
- The decorated function must take **one** argument.
- The argument is the identity used to create this JWT.
- The decorated function must return a *string* which is the secrete key used to
- encode the JWT.
- """
- self._encode_key_callback = callback
- return callback
- def expired_token_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function for returning a custom
- response when an expired JWT is encountered.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must return a Flask Response.
- """
- self._expired_token_callback = callback
- return callback
- def invalid_token_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function for returning a custom
- response when an invalid JWT is encountered.
- This decorator sets the callback function that will be used if an
- invalid JWT attempts to access a protected endpoint.
- The decorated function must take **one** argument.
- The argument is a string which contains the reason why a token is invalid.
- The decorated function must return a Flask Response.
- """
- self._invalid_token_callback = callback
- return callback
- def needs_fresh_token_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function for returning a custom
- response when a valid and non-fresh token is used on an endpoint
- that is marked as ``fresh=True``.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must return a Flask Response.
- """
- self._needs_fresh_token_callback = callback
- return callback
- def revoked_token_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function for returning a custom
- response when a revoked token is encountered.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must return a Flask Response.
- """
- self._revoked_token_callback = callback
- return callback
- def token_in_blocklist_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to check if a JWT has
- been revoked.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must be return ``True`` if the token has been
- revoked, ``False`` otherwise.
- """
- self._token_in_blocklist_callback = callback
- return callback
- def token_verification_failed_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to return a custom
- response when the claims verification check fails.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must return a Flask Response.
- """
- self._token_verification_failed_callback = callback
- return callback
- def token_verification_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used for custom verification
- of a valid JWT.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must return ``True`` if the token is valid, or
- ``False`` otherwise.
- """
- self._token_verification_callback = callback
- return callback
- def unauthorized_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to return a custom
- response when no JWT is present.
- The decorated function must take **one** argument.
- The argument is a string that explains why the JWT could not be found.
- The decorated function must return a Flask Response.
- """
- self._unauthorized_callback = callback
- return callback
- def user_identity_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to convert an identity to
- a JSON serializable format when creating JWTs. This is useful for
- using objects (such as SQLAlchemy instances) as the identity when
- creating your tokens.
- The decorated function must take **one** argument.
- The argument is the identity that was used when creating a JWT.
- The decorated function must return JSON serializable data.
- """
- self._user_identity_callback = callback
- return callback
- def user_lookup_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to convert a JWT into
- a python object that can be used in a protected endpoint. This is useful
- for automatically loading a SQLAlchemy instance based on the contents
- of the JWT.
- The object returned from this function can be accessed via
- :attr:`~flask_jwt_extended.current_user` or
- :meth:`~flask_jwt_extended.get_current_user`
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function can return any python object, which can then be
- accessed in a protected endpoint. If an object cannot be loaded, for
- example if a user has been deleted from your database, ``None`` must be
- returned to indicate that an error occurred loading the user.
- """
- self._user_lookup_callback = callback
- return callback
- def user_lookup_error_loader(self, callback: Callable) -> Callable:
- """
- This decorator sets the callback function used to return a custom
- response when loading a user via
- :meth:`~flask_jwt_extended.JWTManager.user_lookup_loader` fails.
- The decorated function must take **two** arguments.
- The first argument is a dictionary containing the header data of the JWT.
- The second argument is a dictionary containing the payload data of the JWT.
- The decorated function must return a Flask Response.
- """
- self._user_lookup_error_callback = callback
- return callback
- def _encode_jwt_from_config(
- self,
- identity: Any,
- token_type: str,
- claims=None,
- fresh: bool = False,
- expires_delta: ExpiresDelta = None,
- headers=None,
- ) -> str:
- header_overrides = self._jwt_additional_header_callback(identity)
- if headers is not None:
- header_overrides.update(headers)
- claim_overrides = self._user_claims_callback(identity)
- if claims is not None:
- claim_overrides.update(claims)
- if expires_delta is None:
- if token_type == "access":
- expires_delta = config.access_expires
- else:
- expires_delta = config.refresh_expires
- return _encode_jwt(
- algorithm=config.algorithm,
- audience=config.encode_audience,
- claim_overrides=claim_overrides,
- csrf=config.csrf_protect,
- expires_delta=expires_delta,
- fresh=fresh,
- header_overrides=header_overrides,
- identity=self._user_identity_callback(identity),
- identity_claim_key=config.identity_claim_key,
- issuer=config.encode_issuer,
- json_encoder=config.json_encoder,
- secret=self._encode_key_callback(identity),
- token_type=token_type,
- nbf=config.encode_nbf,
- )
- def _decode_jwt_from_config(
- self, encoded_token: str, csrf_value=None, allow_expired: bool = False
- ) -> dict:
- unverified_claims = jwt.decode(
- encoded_token,
- algorithms=config.decode_algorithms,
- options={"verify_signature": False},
- )
- unverified_headers = jwt.get_unverified_header(encoded_token)
- secret = self._decode_key_callback(unverified_headers, unverified_claims)
- kwargs = {
- "algorithms": config.decode_algorithms,
- "audience": config.decode_audience,
- "csrf_value": csrf_value,
- "encoded_token": encoded_token,
- "identity_claim_key": config.identity_claim_key,
- "issuer": config.decode_issuer,
- "leeway": config.leeway,
- "secret": secret,
- "verify_aud": config.decode_audience is not None,
- }
- try:
- return _decode_jwt(**kwargs, allow_expired=allow_expired)
- except ExpiredSignatureError as e:
- # TODO: If we ever do another breaking change, don't raise this pyjwt
- # error directly, instead raise a custom error of ours from this
- # error.
- e.jwt_header = unverified_headers # type: ignore
- e.jwt_data = _decode_jwt(**kwargs, allow_expired=True) # type: ignore
- raise
|