api_jws.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import binascii
  2. import json
  3. from collections.abc import Mapping
  4. from typing import Any, Dict, List, Optional, Type
  5. from .algorithms import (
  6. Algorithm,
  7. get_default_algorithms,
  8. has_crypto,
  9. requires_cryptography,
  10. )
  11. from .exceptions import (
  12. DecodeError,
  13. InvalidAlgorithmError,
  14. InvalidSignatureError,
  15. InvalidTokenError,
  16. )
  17. from .utils import base64url_decode, base64url_encode
  class PyJWS:
      """Create and verify JSON Web Signatures (JWS) in compact serialization.

      Tokens have the form ``<header>.<payload>.<signature>``, each segment
      base64url-encoded. Every instance keeps its own table of usable
      algorithms (keyed by ``alg`` header value), the set of ``alg`` names it
      allows, and a dict of default options used by ``decode_complete``.
      """

      # Default value of the ``typ`` header written by ``encode``.
      header_typ = "JWT"

      def __init__(self, algorithms=None, options=None):
          """Build the algorithm table and default options.

          :param algorithms: optional iterable of ``alg`` names to allow. Every
              default algorithm whose name is not listed is removed from this
              instance. ``None`` allows all default algorithms.
          :param options: optional dict layered over ``_get_default_options()``.
          """
          self._algorithms = get_default_algorithms()
          self._valid_algs = (
              set(algorithms) if algorithms is not None else set(self._algorithms)
          )

          # Remove algorithms that aren't on the whitelist. Names listed in
          # ``algorithms`` that have no default implementation stay in
          # ``_valid_algs`` (so ``get_algorithms`` reports them) even though no
          # handler exists for them unless one is registered later.
          for key in list(self._algorithms):
              if key not in self._valid_algs:
                  del self._algorithms[key]

          if options is None:
              options = {}
          self.options = {**self._get_default_options(), **options}

      @staticmethod
      def _get_default_options():
          """Return the option defaults: signature verification enabled."""
          return {"verify_signature": True}

      def register_algorithm(self, alg_id, alg_obj):
          """
          Registers a new Algorithm for use when creating and verifying tokens.

          :param alg_id: the ``alg`` header value the handler answers to.
          :param alg_obj: an ``Algorithm`` instance.
          :raises ValueError: if ``alg_id`` already has a handler.
          :raises TypeError: if ``alg_obj`` is not an ``Algorithm``.
          """
          if alg_id in self._algorithms:
              raise ValueError("Algorithm already has a handler.")

          if not isinstance(alg_obj, Algorithm):
              raise TypeError("Object is not of type `Algorithm`")

          self._algorithms[alg_id] = alg_obj
          self._valid_algs.add(alg_id)

      def unregister_algorithm(self, alg_id):
          """
          Unregisters an Algorithm for use when creating and verifying tokens.

          :param alg_id: the ``alg`` name to remove.
          :raises KeyError: if ``alg_id`` has no registered handler.
          """
          if alg_id not in self._algorithms:
              raise KeyError(
                  "The specified algorithm could not be removed"
                  " because it is not registered."
              )

          del self._algorithms[alg_id]
          self._valid_algs.remove(alg_id)

      def get_algorithms(self):
          """
          Returns a list of supported values for the 'alg' parameter.
          """
          return list(self._valid_algs)

      def encode(
          self,
          payload: bytes,
          key: str,
          algorithm: Optional[str] = "HS256",
          headers: Optional[Dict] = None,
          json_encoder: Optional[Type[json.JSONEncoder]] = None,
          is_payload_detached: bool = False,
      ) -> str:
          """Sign ``payload`` and return the compact JWS as a ``str``.

          :param payload: raw bytes to sign.
          :param key: key material; passed through the algorithm's
              ``prepare_key`` before signing.
          :param algorithm: ``alg`` name; ``None`` means ``"none"``. A truthy
              ``headers["alg"]`` takes precedence over this argument.
          :param headers: extra header fields merged over the defaults
              (``typ`` and ``alg``). A falsy ``typ`` removes that field.
          :param json_encoder: ``json.JSONEncoder`` subclass used to serialise
              the header.
          :param is_payload_detached: sign the payload without base64url
              encoding it, write ``"b64": False`` into the header and leave the
              payload segment of the returned token empty. Also switched on by
              ``headers["b64"] is False``.
          :raises NotImplementedError: if no handler is registered for the
              resolved algorithm.
          """
          segments = []

          if algorithm is None:
              algorithm = "none"

          # Prefer headers values if present to function parameters.
          if headers:
              headers_alg = headers.get("alg")
              if headers_alg:
                  algorithm = headers["alg"]

              headers_b64 = headers.get("b64")
              if headers_b64 is False:
                  is_payload_detached = True

          # Header
          header: Dict[str, Any] = {"typ": self.header_typ, "alg": algorithm}

          if headers:
              self._validate_headers(headers)
              header.update(headers)

          if not header["typ"]:
              del header["typ"]

          if is_payload_detached:
              header["b64"] = False
          elif "b64" in header:
              # True is the standard value for b64, so no need for it
              del header["b64"]

          json_header = json.dumps(
              header, separators=(",", ":"), cls=json_encoder
          ).encode()

          segments.append(base64url_encode(json_header))

          if is_payload_detached:
              msg_payload = payload
          else:
              msg_payload = base64url_encode(payload)
          segments.append(msg_payload)

          # Segments
          signing_input = b".".join(segments)

          # Only the table lookup belongs in the try: a KeyError raised inside
          # prepare_key()/sign() must not be reported as an unknown algorithm.
          try:
              alg_obj = self._algorithms[algorithm]
          except KeyError as e:
              if not has_crypto and algorithm in requires_cryptography:
                  raise NotImplementedError(
                      f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?"
                  ) from e
              raise NotImplementedError("Algorithm not supported") from e

          key = alg_obj.prepare_key(key)
          signature = alg_obj.sign(signing_input, key)

          segments.append(base64url_encode(signature))

          # Don't put the payload content inside the encoded token when detached
          if is_payload_detached:
              segments[1] = b""
          encoded_string = b".".join(segments)

          return encoded_string.decode("utf-8")

      def decode_complete(
          self,
          jwt: str,
          key: str = "",
          algorithms: Optional[List[str]] = None,
          options: Optional[Dict] = None,
          detached_payload: Optional[bytes] = None,
          **kwargs,
      ) -> Dict[str, Any]:
          """Parse ``jwt``, optionally verify its signature, and return its parts.

          :param jwt: the compact token, as ``str`` or ``bytes``.
          :param key: key material passed to the algorithm's ``prepare_key``.
          :param algorithms: allowed ``alg`` values; must be non-empty while
              signature verification is enabled.
          :param options: overrides merged over ``self.options``.
          :param detached_payload: payload to use (and sign-check) when the
              token header has ``"b64": False``; required in that case.
          :param kwargs: accepted for call compatibility and ignored.
          :returns: dict with keys ``payload``, ``header`` and ``signature``.
          :raises DecodeError: on a malformed token, or when ``algorithms`` /
              ``detached_payload`` is required but missing.
          :raises InvalidAlgorithmError: if the header's ``alg`` is not allowed
              or has no handler.
          :raises InvalidSignatureError: if the signature does not verify.
          """
          if options is None:
              options = {}
          merged_options = {**self.options, **options}
          verify_signature = merged_options["verify_signature"]

          if verify_signature and not algorithms:
              raise DecodeError(
                  'It is required that you pass in a value for the "algorithms" argument when calling decode().'
              )

          payload, signing_input, header, signature = self._load(jwt)

          if header.get("b64", True) is False:
              if detached_payload is None:
                  raise DecodeError(
                      'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                  )
              payload = detached_payload
              # The token carries an empty payload segment; rebuild the
              # signing input from the header segment plus the raw payload.
              signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

          if verify_signature:
              self._verify_signature(signing_input, header, signature, key, algorithms)

          return {
              "payload": payload,
              "header": header,
              "signature": signature,
          }

      def decode(
          self,
          jwt: str,
          key: str = "",
          algorithms: Optional[List[str]] = None,
          options: Optional[Dict] = None,
          **kwargs,
      ) -> Any:
          """Like ``decode_complete`` but return only the payload.

          The payload is the decoded payload segment (or ``detached_payload``,
          which may be passed through ``kwargs``), not a ``str``.
          """
          decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
          return decoded["payload"]

      def get_unverified_header(self, jwt):
          """Returns back the JWT header parameters as a dict()

          Note: The signature is not verified so the header parameters
          should not be fully trusted until signature verification is complete
          """
          headers = self._load(jwt)[2]
          self._validate_headers(headers)

          return headers

      def _load(self, jwt):
          """Split a compact token into its decoded parts; verifies nothing.

          :returns: ``(payload, signing_input, header, signature)`` where
              ``signing_input`` is the raw ``<header>.<payload>`` bytes and
              ``header`` is the parsed JSON object.
          :raises DecodeError: on a non-str/bytes token, too few segments,
              undecodable base64url, or a header that is not a JSON object.
          """
          if isinstance(jwt, str):
              jwt = jwt.encode("utf-8")

          if not isinstance(jwt, bytes):
              raise DecodeError(f"Invalid token type. Token must be a {bytes}")

          try:
              signing_input, crypto_segment = jwt.rsplit(b".", 1)
              header_segment, payload_segment = signing_input.split(b".", 1)
          except ValueError as err:
              raise DecodeError("Not enough segments") from err

          try:
              header_data = base64url_decode(header_segment)
          except (TypeError, binascii.Error) as err:
              raise DecodeError("Invalid header padding") from err

          try:
              header = json.loads(header_data)
          except ValueError as e:
              raise DecodeError(f"Invalid header string: {e}") from e

          if not isinstance(header, Mapping):
              raise DecodeError("Invalid header string: must be a json object")

          try:
              payload = base64url_decode(payload_segment)
          except (TypeError, binascii.Error) as err:
              raise DecodeError("Invalid payload padding") from err

          try:
              signature = base64url_decode(crypto_segment)
          except (TypeError, binascii.Error) as err:
              raise DecodeError("Invalid crypto padding") from err

          return (payload, signing_input, header, signature)

      def _verify_signature(
          self,
          signing_input,
          header,
          signature,
          key="",
          algorithms=None,
      ):
          """Check ``signature`` over ``signing_input`` using the header's ``alg``.

          :raises InvalidAlgorithmError: if ``alg`` is not in ``algorithms``
              (when given) or has no registered handler.
          :raises InvalidSignatureError: if verification fails.
          """
          alg = header.get("alg")

          if algorithms is not None and alg not in algorithms:
              raise InvalidAlgorithmError("The specified alg value is not allowed")

          # Only the table lookup belongs in the try: a KeyError raised inside
          # prepare_key()/verify() must not be reported as an unknown algorithm.
          try:
              alg_obj = self._algorithms[alg]
          except KeyError as e:
              raise InvalidAlgorithmError("Algorithm not supported") from e

          key = alg_obj.prepare_key(key)
          if not alg_obj.verify(signing_input, key, signature):
              raise InvalidSignatureError("Signature verification failed")

      def _validate_headers(self, headers):
          """Validate header fields that have type constraints (currently ``kid``)."""
          if "kid" in headers:
              self._validate_kid(headers["kid"])

      def _validate_kid(self, kid):
          """Raise ``InvalidTokenError`` unless ``kid`` is a ``str``."""
          if not isinstance(kid, str):
              raise InvalidTokenError("Key ID header parameter must be a string")
  221. def _validate_kid(self, kid):
  222. if not isinstance(kid, str):
  223. raise InvalidTokenError("Key ID header parameter must be a string")
  224. _jws_global_obj = PyJWS()
  225. encode = _jws_global_obj.encode
  226. decode_complete = _jws_global_obj.decode_complete
  227. decode = _jws_global_obj.decode
  228. register_algorithm = _jws_global_obj.register_algorithm
  229. unregister_algorithm = _jws_global_obj.unregister_algorithm
  230. get_unverified_header = _jws_global_obj.get_unverified_header