tokens.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import uuid
  2. from datetime import datetime
  3. from datetime import timedelta
  4. from datetime import timezone
  5. from hmac import compare_digest
  6. from typing import Any
  7. from typing import Iterable
  8. from typing import List
  9. from typing import Type
  10. from typing import Union
  11. import jwt
  12. from flask.json import JSONEncoder
  13. from flask_jwt_extended.exceptions import CSRFError
  14. from flask_jwt_extended.exceptions import JWTDecodeError
  15. from flask_jwt_extended.typing import ExpiresDelta
  16. def _encode_jwt(
  17. algorithm: str,
  18. audience: Union[str, Iterable[str]],
  19. claim_overrides: dict,
  20. csrf: bool,
  21. expires_delta: ExpiresDelta,
  22. fresh: bool,
  23. header_overrides: dict,
  24. identity: Any,
  25. identity_claim_key: str,
  26. issuer: str,
  27. json_encoder: Type[JSONEncoder],
  28. secret: str,
  29. token_type: str,
  30. nbf: bool,
  31. ) -> str:
  32. now = datetime.now(timezone.utc)
  33. if isinstance(fresh, timedelta):
  34. fresh = datetime.timestamp(now + fresh)
  35. token_data = {
  36. "fresh": fresh,
  37. "iat": now,
  38. "jti": str(uuid.uuid4()),
  39. "type": token_type,
  40. identity_claim_key: identity,
  41. }
  42. if nbf:
  43. token_data["nbf"] = now
  44. if csrf:
  45. token_data["csrf"] = str(uuid.uuid4())
  46. if audience:
  47. token_data["aud"] = audience
  48. if issuer:
  49. token_data["iss"] = issuer
  50. if expires_delta:
  51. token_data["exp"] = now + expires_delta
  52. if claim_overrides:
  53. token_data.update(claim_overrides)
  54. return jwt.encode(
  55. token_data,
  56. secret,
  57. algorithm,
  58. json_encoder=json_encoder, # type: ignore
  59. headers=header_overrides,
  60. )
  61. def _decode_jwt(
  62. algorithms: List,
  63. allow_expired: bool,
  64. audience: Union[str, Iterable[str]],
  65. csrf_value: str,
  66. encoded_token: str,
  67. identity_claim_key: str,
  68. issuer: str,
  69. leeway: int,
  70. secret: str,
  71. verify_aud: bool,
  72. ) -> dict:
  73. options = {"verify_aud": verify_aud}
  74. if allow_expired:
  75. options["verify_exp"] = False
  76. # This call verifies the ext, iat, and nbf claims
  77. # This optionally verifies the exp and aud claims if enabled
  78. decoded_token = jwt.decode(
  79. encoded_token,
  80. secret,
  81. algorithms=algorithms,
  82. audience=audience,
  83. issuer=issuer,
  84. leeway=leeway,
  85. options=options,
  86. )
  87. # Make sure that any custom claims we expect in the token are present
  88. if identity_claim_key not in decoded_token:
  89. raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))
  90. if "type" not in decoded_token:
  91. decoded_token["type"] = "access"
  92. if "fresh" not in decoded_token:
  93. decoded_token["fresh"] = False
  94. if "jti" not in decoded_token:
  95. decoded_token["jti"] = None
  96. if csrf_value:
  97. if "csrf" not in decoded_token:
  98. raise JWTDecodeError("Missing claim: csrf")
  99. if not compare_digest(decoded_token["csrf"], csrf_value):
  100. raise CSRFError("CSRF double submit tokens do not match")
  101. return decoded_token