jwt_manager.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. import datetime
  2. from typing import Any
  3. from typing import Callable
  4. from typing import Optional
  5. import jwt
  6. from flask import Flask
  7. from jwt import DecodeError
  8. from jwt import ExpiredSignatureError
  9. from jwt import InvalidAudienceError
  10. from jwt import InvalidIssuerError
  11. from jwt import InvalidTokenError
  12. from jwt import MissingRequiredClaimError
  13. from flask_jwt_extended.config import config
  14. from flask_jwt_extended.default_callbacks import default_additional_claims_callback
  15. from flask_jwt_extended.default_callbacks import default_blocklist_callback
  16. from flask_jwt_extended.default_callbacks import default_decode_key_callback
  17. from flask_jwt_extended.default_callbacks import default_encode_key_callback
  18. from flask_jwt_extended.default_callbacks import default_expired_token_callback
  19. from flask_jwt_extended.default_callbacks import default_invalid_token_callback
  20. from flask_jwt_extended.default_callbacks import default_jwt_headers_callback
  21. from flask_jwt_extended.default_callbacks import default_needs_fresh_token_callback
  22. from flask_jwt_extended.default_callbacks import default_revoked_token_callback
  23. from flask_jwt_extended.default_callbacks import default_token_verification_callback
  24. from flask_jwt_extended.default_callbacks import (
  25. default_token_verification_failed_callback,
  26. )
  27. from flask_jwt_extended.default_callbacks import default_unauthorized_callback
  28. from flask_jwt_extended.default_callbacks import default_user_identity_callback
  29. from flask_jwt_extended.default_callbacks import default_user_lookup_error_callback
  30. from flask_jwt_extended.exceptions import CSRFError
  31. from flask_jwt_extended.exceptions import FreshTokenRequired
  32. from flask_jwt_extended.exceptions import InvalidHeaderError
  33. from flask_jwt_extended.exceptions import InvalidQueryParamError
  34. from flask_jwt_extended.exceptions import JWTDecodeError
  35. from flask_jwt_extended.exceptions import NoAuthorizationError
  36. from flask_jwt_extended.exceptions import RevokedTokenError
  37. from flask_jwt_extended.exceptions import UserClaimsVerificationError
  38. from flask_jwt_extended.exceptions import UserLookupError
  39. from flask_jwt_extended.exceptions import WrongTokenError
  40. from flask_jwt_extended.tokens import _decode_jwt
  41. from flask_jwt_extended.tokens import _encode_jwt
  42. from flask_jwt_extended.typing import ExpiresDelta
  43. class JWTManager(object):
  44. """
  45. An object used to hold JWT settings and callback functions for the
  46. Flask-JWT-Extended extension.
  47. Instances of :class:`JWTManager` are *not* bound to specific apps, so
  48. you can create one in the main body of your code and then bind it
  49. to your app in a factory function.
  50. """
  51. def __init__(self, app: Flask = None) -> None:
  52. """
  53. Create the JWTManager instance. You can either pass a flask application
  54. in directly here to register this extension with the flask app, or
  55. call init_app after creating this object (in a factory pattern).
  56. :param app:
  57. The Flask Application object
  58. """
  59. # Register the default error handler callback methods. These can be
  60. # overridden with the appropriate loader decorators
  61. self._decode_key_callback = default_decode_key_callback
  62. self._encode_key_callback = default_encode_key_callback
  63. self._expired_token_callback = default_expired_token_callback
  64. self._invalid_token_callback = default_invalid_token_callback
  65. self._jwt_additional_header_callback = default_jwt_headers_callback
  66. self._needs_fresh_token_callback = default_needs_fresh_token_callback
  67. self._revoked_token_callback = default_revoked_token_callback
  68. self._token_in_blocklist_callback = default_blocklist_callback
  69. self._token_verification_callback = default_token_verification_callback
  70. self._unauthorized_callback = default_unauthorized_callback
  71. self._user_claims_callback = default_additional_claims_callback
  72. self._user_identity_callback = default_user_identity_callback
  73. self._user_lookup_callback: Optional[Callable] = None
  74. self._user_lookup_error_callback = default_user_lookup_error_callback
  75. self._token_verification_failed_callback = (
  76. default_token_verification_failed_callback
  77. )
  78. # Register this extension with the flask app now (if it is provided)
  79. if app is not None:
  80. self.init_app(app)
  81. def init_app(self, app: Flask) -> None:
  82. """
  83. Register this extension with the flask app.
  84. :param app:
  85. The Flask Application object
  86. """
  87. # Save this so we can use it later in the extension
  88. if not hasattr(app, "extensions"): # pragma: no cover
  89. app.extensions = {}
  90. app.extensions["flask-jwt-extended"] = self
  91. # Set all the default configurations for this extension
  92. self._set_default_configuration_options(app)
  93. self._set_error_handler_callbacks(app)
  94. def _set_error_handler_callbacks(self, app: Flask) -> None:
  95. @app.errorhandler(CSRFError)
  96. def handle_csrf_error(e):
  97. return self._unauthorized_callback(str(e))
  98. @app.errorhandler(DecodeError)
  99. def handle_decode_error(e):
  100. return self._invalid_token_callback(str(e))
  101. @app.errorhandler(ExpiredSignatureError)
  102. def handle_expired_error(e):
  103. return self._expired_token_callback(e.jwt_header, e.jwt_data)
  104. @app.errorhandler(FreshTokenRequired)
  105. def handle_fresh_token_required(e):
  106. return self._needs_fresh_token_callback(e.jwt_header, e.jwt_data)
  107. @app.errorhandler(MissingRequiredClaimError)
  108. def handle_missing_required_claim_error(e):
  109. return self._invalid_token_callback(str(e))
  110. @app.errorhandler(InvalidAudienceError)
  111. def handle_invalid_audience_error(e):
  112. return self._invalid_token_callback(str(e))
  113. @app.errorhandler(InvalidIssuerError)
  114. def handle_invalid_issuer_error(e):
  115. return self._invalid_token_callback(str(e))
  116. @app.errorhandler(InvalidHeaderError)
  117. def handle_invalid_header_error(e):
  118. return self._invalid_token_callback(str(e))
  119. @app.errorhandler(InvalidTokenError)
  120. def handle_invalid_token_error(e):
  121. return self._invalid_token_callback(str(e))
  122. @app.errorhandler(JWTDecodeError)
  123. def handle_jwt_decode_error(e):
  124. return self._invalid_token_callback(str(e))
  125. @app.errorhandler(NoAuthorizationError)
  126. def handle_auth_error(e):
  127. return self._unauthorized_callback(str(e))
  128. @app.errorhandler(InvalidQueryParamError)
  129. def handle_invalid_query_param_error(e):
  130. return self._invalid_token_callback(str(e))
  131. @app.errorhandler(RevokedTokenError)
  132. def handle_revoked_token_error(e):
  133. return self._revoked_token_callback(e.jwt_header, e.jwt_data)
  134. @app.errorhandler(UserClaimsVerificationError)
  135. def handle_failed_token_verification(e):
  136. return self._token_verification_failed_callback(e.jwt_header, e.jwt_data)
  137. @app.errorhandler(UserLookupError)
  138. def handler_user_lookup_error(e):
  139. return self._user_lookup_error_callback(e.jwt_header, e.jwt_data)
  140. @app.errorhandler(WrongTokenError)
  141. def handle_wrong_token_error(e):
  142. return self._invalid_token_callback(str(e))
  143. @staticmethod
  144. def _set_default_configuration_options(app: Flask) -> None:
  145. app.config.setdefault(
  146. "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15)
  147. )
  148. app.config.setdefault("JWT_ACCESS_COOKIE_NAME", "access_token_cookie")
  149. app.config.setdefault("JWT_ACCESS_COOKIE_PATH", "/")
  150. app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_NAME", "csrf_access_token")
  151. app.config.setdefault("JWT_ACCESS_CSRF_COOKIE_PATH", "/")
  152. app.config.setdefault("JWT_ACCESS_CSRF_FIELD_NAME", "csrf_token")
  153. app.config.setdefault("JWT_ACCESS_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
  154. app.config.setdefault("JWT_ALGORITHM", "HS256")
  155. app.config.setdefault("JWT_COOKIE_CSRF_PROTECT", True)
  156. app.config.setdefault("JWT_COOKIE_DOMAIN", None)
  157. app.config.setdefault("JWT_COOKIE_SAMESITE", None)
  158. app.config.setdefault("JWT_COOKIE_SECURE", False)
  159. app.config.setdefault("JWT_CSRF_CHECK_FORM", False)
  160. app.config.setdefault("JWT_CSRF_IN_COOKIES", True)
  161. app.config.setdefault("JWT_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
  162. app.config.setdefault("JWT_DECODE_ALGORITHMS", None)
  163. app.config.setdefault("JWT_DECODE_AUDIENCE", None)
  164. app.config.setdefault("JWT_DECODE_ISSUER", None)
  165. app.config.setdefault("JWT_DECODE_LEEWAY", 0)
  166. app.config.setdefault("JWT_ENCODE_AUDIENCE", None)
  167. app.config.setdefault("JWT_ENCODE_ISSUER", None)
  168. app.config.setdefault("JWT_ERROR_MESSAGE_KEY", "msg")
  169. app.config.setdefault("JWT_HEADER_NAME", "Authorization")
  170. app.config.setdefault("JWT_HEADER_TYPE", "Bearer")
  171. app.config.setdefault("JWT_IDENTITY_CLAIM", "sub")
  172. app.config.setdefault("JWT_JSON_KEY", "access_token")
  173. app.config.setdefault("JWT_PRIVATE_KEY", None)
  174. app.config.setdefault("JWT_PUBLIC_KEY", None)
  175. app.config.setdefault("JWT_QUERY_STRING_NAME", "jwt")
  176. app.config.setdefault("JWT_QUERY_STRING_VALUE_PREFIX", "")
  177. app.config.setdefault("JWT_REFRESH_COOKIE_NAME", "refresh_token_cookie")
  178. app.config.setdefault("JWT_REFRESH_COOKIE_PATH", "/")
  179. app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_NAME", "csrf_refresh_token")
  180. app.config.setdefault("JWT_REFRESH_CSRF_COOKIE_PATH", "/")
  181. app.config.setdefault("JWT_REFRESH_CSRF_FIELD_NAME", "csrf_token")
  182. app.config.setdefault("JWT_REFRESH_CSRF_HEADER_NAME", "X-CSRF-TOKEN")
  183. app.config.setdefault("JWT_REFRESH_JSON_KEY", "refresh_token")
  184. app.config.setdefault("JWT_REFRESH_TOKEN_EXPIRES", datetime.timedelta(days=30))
  185. app.config.setdefault("JWT_SECRET_KEY", None)
  186. app.config.setdefault("JWT_SESSION_COOKIE", True)
  187. app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",))
  188. app.config.setdefault("JWT_ENCODE_NBF", True)
  189. def additional_claims_loader(self, callback: Callable) -> Callable:
  190. """
  191. This decorator sets the callback function used to add additional claims
  192. when creating a JWT. The claims returned by this function will be merged
  193. with any claims passed in via the ``additional_claims`` argument to
  194. :func:`~flask_jwt_extended.create_access_token` or
  195. :func:`~flask_jwt_extended.create_refresh_token`.
  196. The decorated function must take **one** argument.
  197. The argument is the identity that was used when creating a JWT.
  198. The decorated function must return a dictionary of claims to add to the JWT.
  199. """
  200. self._user_claims_callback = callback
  201. return callback
  202. def additional_headers_loader(self, callback: Callable) -> Callable:
  203. """
  204. This decorator sets the callback function used to add additional headers
  205. when creating a JWT. The headers returned by this function will be merged
  206. with any headers passed in via the ``additional_headers`` argument to
  207. :func:`~flask_jwt_extended.create_access_token` or
  208. :func:`~flask_jwt_extended.create_refresh_token`.
  209. The decorated function must take **one** argument.
  210. The argument is the identity that was used when creating a JWT.
  211. The decorated function must return a dictionary of headers to add to the JWT.
  212. """
  213. self._jwt_additional_header_callback = callback
  214. return callback
  215. def decode_key_loader(self, callback: Callable) -> Callable:
  216. """
  217. This decorator sets the callback function for dynamically setting the JWT
  218. decode key based on the **UNVERIFIED** contents of the token. Think
  219. carefully before using this functionality, in most cases you probably
  220. don't need it.
  221. The decorated function must take **two** arguments.
  222. The first argument is a dictionary containing the header data of the
  223. unverified JWT.
  224. The second argument is a dictionary containing the payload data of the
  225. unverified JWT.
  226. The decorated function must return a *string* that is used to decode and
  227. verify the token.
  228. """
  229. self._decode_key_callback = callback
  230. return callback
  231. def encode_key_loader(self, callback: Callable) -> Callable:
  232. """
  233. This decorator sets the callback function for dynamically setting the JWT
  234. encode key based on the tokens identity. Think carefully before using this
  235. functionality, in most cases you probably don't need it.
  236. The decorated function must take **one** argument.
  237. The argument is the identity used to create this JWT.
  238. The decorated function must return a *string* which is the secrete key used to
  239. encode the JWT.
  240. """
  241. self._encode_key_callback = callback
  242. return callback
  243. def expired_token_loader(self, callback: Callable) -> Callable:
  244. """
  245. This decorator sets the callback function for returning a custom
  246. response when an expired JWT is encountered.
  247. The decorated function must take **two** arguments.
  248. The first argument is a dictionary containing the header data of the JWT.
  249. The second argument is a dictionary containing the payload data of the JWT.
  250. The decorated function must return a Flask Response.
  251. """
  252. self._expired_token_callback = callback
  253. return callback
  254. def invalid_token_loader(self, callback: Callable) -> Callable:
  255. """
  256. This decorator sets the callback function for returning a custom
  257. response when an invalid JWT is encountered.
  258. This decorator sets the callback function that will be used if an
  259. invalid JWT attempts to access a protected endpoint.
  260. The decorated function must take **one** argument.
  261. The argument is a string which contains the reason why a token is invalid.
  262. The decorated function must return a Flask Response.
  263. """
  264. self._invalid_token_callback = callback
  265. return callback
  266. def needs_fresh_token_loader(self, callback: Callable) -> Callable:
  267. """
  268. This decorator sets the callback function for returning a custom
  269. response when a valid and non-fresh token is used on an endpoint
  270. that is marked as ``fresh=True``.
  271. The decorated function must take **two** arguments.
  272. The first argument is a dictionary containing the header data of the JWT.
  273. The second argument is a dictionary containing the payload data of the JWT.
  274. The decorated function must return a Flask Response.
  275. """
  276. self._needs_fresh_token_callback = callback
  277. return callback
  278. def revoked_token_loader(self, callback: Callable) -> Callable:
  279. """
  280. This decorator sets the callback function for returning a custom
  281. response when a revoked token is encountered.
  282. The decorated function must take **two** arguments.
  283. The first argument is a dictionary containing the header data of the JWT.
  284. The second argument is a dictionary containing the payload data of the JWT.
  285. The decorated function must return a Flask Response.
  286. """
  287. self._revoked_token_callback = callback
  288. return callback
  289. def token_in_blocklist_loader(self, callback: Callable) -> Callable:
  290. """
  291. This decorator sets the callback function used to check if a JWT has
  292. been revoked.
  293. The decorated function must take **two** arguments.
  294. The first argument is a dictionary containing the header data of the JWT.
  295. The second argument is a dictionary containing the payload data of the JWT.
  296. The decorated function must be return ``True`` if the token has been
  297. revoked, ``False`` otherwise.
  298. """
  299. self._token_in_blocklist_callback = callback
  300. return callback
  301. def token_verification_failed_loader(self, callback: Callable) -> Callable:
  302. """
  303. This decorator sets the callback function used to return a custom
  304. response when the claims verification check fails.
  305. The decorated function must take **two** arguments.
  306. The first argument is a dictionary containing the header data of the JWT.
  307. The second argument is a dictionary containing the payload data of the JWT.
  308. The decorated function must return a Flask Response.
  309. """
  310. self._token_verification_failed_callback = callback
  311. return callback
  312. def token_verification_loader(self, callback: Callable) -> Callable:
  313. """
  314. This decorator sets the callback function used for custom verification
  315. of a valid JWT.
  316. The decorated function must take **two** arguments.
  317. The first argument is a dictionary containing the header data of the JWT.
  318. The second argument is a dictionary containing the payload data of the JWT.
  319. The decorated function must return ``True`` if the token is valid, or
  320. ``False`` otherwise.
  321. """
  322. self._token_verification_callback = callback
  323. return callback
  324. def unauthorized_loader(self, callback: Callable) -> Callable:
  325. """
  326. This decorator sets the callback function used to return a custom
  327. response when no JWT is present.
  328. The decorated function must take **one** argument.
  329. The argument is a string that explains why the JWT could not be found.
  330. The decorated function must return a Flask Response.
  331. """
  332. self._unauthorized_callback = callback
  333. return callback
  334. def user_identity_loader(self, callback: Callable) -> Callable:
  335. """
  336. This decorator sets the callback function used to convert an identity to
  337. a JSON serializable format when creating JWTs. This is useful for
  338. using objects (such as SQLAlchemy instances) as the identity when
  339. creating your tokens.
  340. The decorated function must take **one** argument.
  341. The argument is the identity that was used when creating a JWT.
  342. The decorated function must return JSON serializable data.
  343. """
  344. self._user_identity_callback = callback
  345. return callback
  346. def user_lookup_loader(self, callback: Callable) -> Callable:
  347. """
  348. This decorator sets the callback function used to convert a JWT into
  349. a python object that can be used in a protected endpoint. This is useful
  350. for automatically loading a SQLAlchemy instance based on the contents
  351. of the JWT.
  352. The object returned from this function can be accessed via
  353. :attr:`~flask_jwt_extended.current_user` or
  354. :meth:`~flask_jwt_extended.get_current_user`
  355. The decorated function must take **two** arguments.
  356. The first argument is a dictionary containing the header data of the JWT.
  357. The second argument is a dictionary containing the payload data of the JWT.
  358. The decorated function can return any python object, which can then be
  359. accessed in a protected endpoint. If an object cannot be loaded, for
  360. example if a user has been deleted from your database, ``None`` must be
  361. returned to indicate that an error occurred loading the user.
  362. """
  363. self._user_lookup_callback = callback
  364. return callback
  365. def user_lookup_error_loader(self, callback: Callable) -> Callable:
  366. """
  367. This decorator sets the callback function used to return a custom
  368. response when loading a user via
  369. :meth:`~flask_jwt_extended.JWTManager.user_lookup_loader` fails.
  370. The decorated function must take **two** arguments.
  371. The first argument is a dictionary containing the header data of the JWT.
  372. The second argument is a dictionary containing the payload data of the JWT.
  373. The decorated function must return a Flask Response.
  374. """
  375. self._user_lookup_error_callback = callback
  376. return callback
  377. def _encode_jwt_from_config(
  378. self,
  379. identity: Any,
  380. token_type: str,
  381. claims=None,
  382. fresh: bool = False,
  383. expires_delta: ExpiresDelta = None,
  384. headers=None,
  385. ) -> str:
  386. header_overrides = self._jwt_additional_header_callback(identity)
  387. if headers is not None:
  388. header_overrides.update(headers)
  389. claim_overrides = self._user_claims_callback(identity)
  390. if claims is not None:
  391. claim_overrides.update(claims)
  392. if expires_delta is None:
  393. if token_type == "access":
  394. expires_delta = config.access_expires
  395. else:
  396. expires_delta = config.refresh_expires
  397. return _encode_jwt(
  398. algorithm=config.algorithm,
  399. audience=config.encode_audience,
  400. claim_overrides=claim_overrides,
  401. csrf=config.csrf_protect,
  402. expires_delta=expires_delta,
  403. fresh=fresh,
  404. header_overrides=header_overrides,
  405. identity=self._user_identity_callback(identity),
  406. identity_claim_key=config.identity_claim_key,
  407. issuer=config.encode_issuer,
  408. json_encoder=config.json_encoder,
  409. secret=self._encode_key_callback(identity),
  410. token_type=token_type,
  411. nbf=config.encode_nbf,
  412. )
  413. def _decode_jwt_from_config(
  414. self, encoded_token: str, csrf_value=None, allow_expired: bool = False
  415. ) -> dict:
  416. unverified_claims = jwt.decode(
  417. encoded_token,
  418. algorithms=config.decode_algorithms,
  419. options={"verify_signature": False},
  420. )
  421. unverified_headers = jwt.get_unverified_header(encoded_token)
  422. secret = self._decode_key_callback(unverified_headers, unverified_claims)
  423. kwargs = {
  424. "algorithms": config.decode_algorithms,
  425. "audience": config.decode_audience,
  426. "csrf_value": csrf_value,
  427. "encoded_token": encoded_token,
  428. "identity_claim_key": config.identity_claim_key,
  429. "issuer": config.decode_issuer,
  430. "leeway": config.leeway,
  431. "secret": secret,
  432. "verify_aud": config.decode_audience is not None,
  433. }
  434. try:
  435. return _decode_jwt(**kwargs, allow_expired=allow_expired)
  436. except ExpiredSignatureError as e:
  437. # TODO: If we ever do another breaking change, don't raise this pyjwt
  438. # error directly, instead raise a custom error of ours from this
  439. # error.
  440. e.jwt_header = unverified_headers # type: ignore
  441. e.jwt_data = _decode_jwt(**kwargs, allow_expired=True) # type: ignore
  442. raise