123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- import uuid
- from datetime import datetime
- from datetime import timedelta
- from datetime import timezone
- from hmac import compare_digest
- from typing import Any
- from typing import Iterable
- from typing import List
- from typing import Type
- from typing import Union
- import jwt
- from flask.json import JSONEncoder
- from flask_jwt_extended.exceptions import CSRFError
- from flask_jwt_extended.exceptions import JWTDecodeError
- from flask_jwt_extended.typing import ExpiresDelta
def _encode_jwt(
    algorithm: str,
    audience: Union[str, Iterable[str]],
    claim_overrides: dict,
    csrf: bool,
    expires_delta: ExpiresDelta,
    fresh: Union[bool, float, timedelta],
    header_overrides: dict,
    identity: Any,
    identity_claim_key: str,
    issuer: str,
    json_encoder: Type[JSONEncoder],
    secret: str,
    token_type: str,
    nbf: bool,
) -> str:
    """Build the claim set for a new token and sign it with ``jwt.encode``.

    Args:
        algorithm: Signing algorithm name passed to ``jwt.encode``.
        audience: Value for the ``aud`` claim; the claim is omitted when falsy.
        claim_overrides: Extra claims merged in last, so they replace any
            claim generated here (including ``type``, ``jti``, ``exp``...).
        csrf: When true, add a random ``csrf`` claim (a UUID4 string).
        expires_delta: Lifetime added to the issue time to form ``exp``.
            Any ``timedelta`` -- including a zero-length one, which yields
            ``exp == iat`` -- sets the claim; ``False`` (or another falsy
            non-timedelta value) omits ``exp`` entirely.
        fresh: Value of the ``fresh`` claim. A ``timedelta`` is converted to
            the POSIX timestamp at which freshness ends; any other value is
            stored unchanged.
        header_overrides: Extra header fields passed to ``jwt.encode``.
        identity: Value stored under ``identity_claim_key``.
        identity_claim_key: Name of the claim that holds ``identity``.
        issuer: Value for the ``iss`` claim; the claim is omitted when falsy.
        json_encoder: JSON encoder class handed to ``jwt.encode``.
        secret: Key passed to ``jwt.encode`` for signing.
        token_type: Value of the ``type`` claim.
        nbf: When true, add an ``nbf`` claim equal to the issue time.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    # A single "now" keeps iat, nbf, exp and a timedelta-based fresh consistent.
    now = datetime.now(timezone.utc)

    if isinstance(fresh, timedelta):
        fresh = datetime.timestamp(now + fresh)

    token_data = {
        "fresh": fresh,
        "iat": now,
        "jti": str(uuid.uuid4()),
        "type": token_type,
        identity_claim_key: identity,
    }
    if nbf:
        token_data["nbf"] = now
    if csrf:
        token_data["csrf"] = str(uuid.uuid4())
    if audience:
        token_data["aud"] = audience
    if issuer:
        token_data["iss"] = issuer
    # timedelta(0) is falsy, so a bare truthiness test would silently issue a
    # token with no ``exp`` -- one that never expires. Any timedelta is a real
    # lifetime; other values keep the original truthiness semantics.
    if isinstance(expires_delta, timedelta) or expires_delta:
        token_data["exp"] = now + expires_delta
    if claim_overrides:
        token_data.update(claim_overrides)

    return jwt.encode(
        token_data,
        secret,
        algorithm,
        json_encoder=json_encoder,  # type: ignore
        headers=header_overrides,
    )
def _decode_jwt(
    algorithms: List,
    allow_expired: bool,
    audience: Union[str, Iterable[str]],
    csrf_value: str,
    encoded_token: str,
    identity_claim_key: str,
    issuer: str,
    leeway: int,
    secret: str,
    verify_aud: bool,
) -> dict:
    """Decode and validate an encoded token, then check its custom claims.

    Args:
        algorithms: Algorithms accepted by ``jwt.decode``.
        allow_expired: When true, expiry (``exp``) is not verified.
        audience: Expected audience, passed to ``jwt.decode``.
        csrf_value: Double-submit CSRF value to check against the token's
            ``csrf`` claim; an empty/falsy value skips the CSRF check.
        encoded_token: The token to decode.
        identity_claim_key: Name of the claim that must hold the identity.
        issuer: Expected issuer, passed to ``jwt.decode``.
        leeway: Clock-skew allowance passed to ``jwt.decode``.
        secret: Key passed to ``jwt.decode`` for signature verification.
        verify_aud: Whether ``jwt.decode`` should verify the ``aud`` claim.

    Returns:
        The decoded claims, with ``type``, ``fresh`` and ``jti`` defaulted to
        ``"access"``, ``False`` and ``None`` when absent.

    Raises:
        JWTDecodeError: If the identity claim is missing, or a CSRF value was
            supplied but the token has no ``csrf`` claim.
        CSRFError: If the token's ``csrf`` claim does not match ``csrf_value``.
        Any exception raised by ``jwt.decode`` propagates unchanged.
    """
    options = {"verify_aud": verify_aud}
    if allow_expired:
        options["verify_exp"] = False

    # jwt.decode verifies the signature and the registered time claims
    # (exp -- unless disabled above -- nbf and iat); aud is verified only when
    # verify_aud is true, and iss when an issuer is given.
    decoded_token = jwt.decode(
        encoded_token,
        secret,
        algorithms=algorithms,
        audience=audience,
        issuer=issuer,
        leeway=leeway,
        options=options,
    )

    # The identity claim is mandatory.
    if identity_claim_key not in decoded_token:
        raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))

    # Defaults for claims that _encode_jwt always writes but that a token
    # might lack; existing values (even None) are left untouched.
    decoded_token.setdefault("type", "access")
    decoded_token.setdefault("fresh", False)
    decoded_token.setdefault("jti", None)

    if csrf_value:
        if "csrf" not in decoded_token:
            raise JWTDecodeError("Missing claim: csrf")
        token_csrf = decoded_token["csrf"]
        # hmac.compare_digest raises TypeError for str arguments containing
        # non-ASCII characters, and csrf_value is presumably request-supplied
        # (untrusted). Compare UTF-8 bytes so any mismatch -- including a
        # non-string claim -- surfaces as CSRFError instead of a crash.
        if not isinstance(token_csrf, str) or not compare_digest(
            token_csrf.encode("utf-8"), csrf_value.encode("utf-8")
        ):
            raise CSRFError("CSRF double submit tokens do not match")

    return decoded_token
|