algorithms.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. import hashlib
  2. import hmac
  3. import json
  4. from .exceptions import InvalidKeyError
  5. from .utils import (
  6. base64url_decode,
  7. base64url_encode,
  8. der_to_raw_signature,
  9. force_bytes,
  10. from_base64url_uint,
  11. is_pem_format,
  12. is_ssh_key,
  13. raw_to_der_signature,
  14. to_base64url_uint,
  15. )
  16. try:
  17. import cryptography.exceptions
  18. from cryptography.exceptions import InvalidSignature
  19. from cryptography.hazmat.primitives import hashes
  20. from cryptography.hazmat.primitives.asymmetric import ec, padding
  21. from cryptography.hazmat.primitives.asymmetric.ec import (
  22. EllipticCurvePrivateKey,
  23. EllipticCurvePublicKey,
  24. )
  25. from cryptography.hazmat.primitives.asymmetric.ed448 import (
  26. Ed448PrivateKey,
  27. Ed448PublicKey,
  28. )
  29. from cryptography.hazmat.primitives.asymmetric.ed25519 import (
  30. Ed25519PrivateKey,
  31. Ed25519PublicKey,
  32. )
  33. from cryptography.hazmat.primitives.asymmetric.rsa import (
  34. RSAPrivateKey,
  35. RSAPrivateNumbers,
  36. RSAPublicKey,
  37. RSAPublicNumbers,
  38. rsa_crt_dmp1,
  39. rsa_crt_dmq1,
  40. rsa_crt_iqmp,
  41. rsa_recover_prime_factors,
  42. )
  43. from cryptography.hazmat.primitives.serialization import (
  44. Encoding,
  45. NoEncryption,
  46. PrivateFormat,
  47. PublicFormat,
  48. load_pem_private_key,
  49. load_pem_public_key,
  50. load_ssh_public_key,
  51. )
  52. has_crypto = True
  53. except ModuleNotFoundError:
  54. has_crypto = False
  55. requires_cryptography = {
  56. "RS256",
  57. "RS384",
  58. "RS512",
  59. "ES256",
  60. "ES256K",
  61. "ES384",
  62. "ES521",
  63. "ES512",
  64. "PS256",
  65. "PS384",
  66. "PS512",
  67. "EdDSA",
  68. }
  69. def get_default_algorithms():
  70. """
  71. Returns the algorithms that are implemented by the library.
  72. """
  73. default_algorithms = {
  74. "none": NoneAlgorithm(),
  75. "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
  76. "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
  77. "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
  78. }
  79. if has_crypto:
  80. default_algorithms.update(
  81. {
  82. "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
  83. "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
  84. "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
  85. "ES256": ECAlgorithm(ECAlgorithm.SHA256),
  86. "ES256K": ECAlgorithm(ECAlgorithm.SHA256),
  87. "ES384": ECAlgorithm(ECAlgorithm.SHA384),
  88. "ES521": ECAlgorithm(ECAlgorithm.SHA512),
  89. "ES512": ECAlgorithm(
  90. ECAlgorithm.SHA512
  91. ), # Backward compat for #219 fix
  92. "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
  93. "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
  94. "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
  95. "EdDSA": OKPAlgorithm(),
  96. }
  97. )
  98. return default_algorithms
  99. class Algorithm:
  100. """
  101. The interface for an algorithm used to sign and verify tokens.
  102. """
  103. def prepare_key(self, key):
  104. """
  105. Performs necessary validation and conversions on the key and returns
  106. the key value in the proper format for sign() and verify().
  107. """
  108. raise NotImplementedError
  109. def sign(self, msg, key):
  110. """
  111. Returns a digital signature for the specified message
  112. using the specified key value.
  113. """
  114. raise NotImplementedError
  115. def verify(self, msg, key, sig):
  116. """
  117. Verifies that the specified digital signature is valid
  118. for the specified message and key values.
  119. """
  120. raise NotImplementedError
  121. @staticmethod
  122. def to_jwk(key_obj):
  123. """
  124. Serializes a given RSA key into a JWK
  125. """
  126. raise NotImplementedError
  127. @staticmethod
  128. def from_jwk(jwk):
  129. """
  130. Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
  131. """
  132. raise NotImplementedError
  133. class NoneAlgorithm(Algorithm):
  134. """
  135. Placeholder for use when no signing or verification
  136. operations are required.
  137. """
  138. def prepare_key(self, key):
  139. if key == "":
  140. key = None
  141. if key is not None:
  142. raise InvalidKeyError('When alg = "none", key value must be None.')
  143. return key
  144. def sign(self, msg, key):
  145. return b""
  146. def verify(self, msg, key, sig):
  147. return False
  148. class HMACAlgorithm(Algorithm):
  149. """
  150. Performs signing and verification operations using HMAC
  151. and the specified hash function.
  152. """
  153. SHA256 = hashlib.sha256
  154. SHA384 = hashlib.sha384
  155. SHA512 = hashlib.sha512
  156. def __init__(self, hash_alg):
  157. self.hash_alg = hash_alg
  158. def prepare_key(self, key):
  159. key = force_bytes(key)
  160. if is_pem_format(key) or is_ssh_key(key):
  161. raise InvalidKeyError(
  162. "The specified key is an asymmetric key or x509 certificate and"
  163. " should not be used as an HMAC secret."
  164. )
  165. return key
  166. @staticmethod
  167. def to_jwk(key_obj):
  168. return json.dumps(
  169. {
  170. "k": base64url_encode(force_bytes(key_obj)).decode(),
  171. "kty": "oct",
  172. }
  173. )
  174. @staticmethod
  175. def from_jwk(jwk):
  176. try:
  177. if isinstance(jwk, str):
  178. obj = json.loads(jwk)
  179. elif isinstance(jwk, dict):
  180. obj = jwk
  181. else:
  182. raise ValueError
  183. except ValueError:
  184. raise InvalidKeyError("Key is not valid JSON")
  185. if obj.get("kty") != "oct":
  186. raise InvalidKeyError("Not an HMAC key")
  187. return base64url_decode(obj["k"])
  188. def sign(self, msg, key):
  189. return hmac.new(key, msg, self.hash_alg).digest()
  190. def verify(self, msg, key, sig):
  191. return hmac.compare_digest(sig, self.sign(msg, key))
  192. if has_crypto:
  193. class RSAAlgorithm(Algorithm):
  194. """
  195. Performs signing and verification operations using
  196. RSASSA-PKCS-v1_5 and the specified hash function.
  197. """
  198. SHA256 = hashes.SHA256
  199. SHA384 = hashes.SHA384
  200. SHA512 = hashes.SHA512
  201. def __init__(self, hash_alg):
  202. self.hash_alg = hash_alg
  203. def prepare_key(self, key):
  204. if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
  205. return key
  206. if not isinstance(key, (bytes, str)):
  207. raise TypeError("Expecting a PEM-formatted key.")
  208. key = force_bytes(key)
  209. try:
  210. if key.startswith(b"ssh-rsa"):
  211. key = load_ssh_public_key(key)
  212. else:
  213. key = load_pem_private_key(key, password=None)
  214. except ValueError:
  215. key = load_pem_public_key(key)
  216. return key
  217. @staticmethod
  218. def to_jwk(key_obj):
  219. obj = None
  220. if getattr(key_obj, "private_numbers", None):
  221. # Private key
  222. numbers = key_obj.private_numbers()
  223. obj = {
  224. "kty": "RSA",
  225. "key_ops": ["sign"],
  226. "n": to_base64url_uint(numbers.public_numbers.n).decode(),
  227. "e": to_base64url_uint(numbers.public_numbers.e).decode(),
  228. "d": to_base64url_uint(numbers.d).decode(),
  229. "p": to_base64url_uint(numbers.p).decode(),
  230. "q": to_base64url_uint(numbers.q).decode(),
  231. "dp": to_base64url_uint(numbers.dmp1).decode(),
  232. "dq": to_base64url_uint(numbers.dmq1).decode(),
  233. "qi": to_base64url_uint(numbers.iqmp).decode(),
  234. }
  235. elif getattr(key_obj, "verify", None):
  236. # Public key
  237. numbers = key_obj.public_numbers()
  238. obj = {
  239. "kty": "RSA",
  240. "key_ops": ["verify"],
  241. "n": to_base64url_uint(numbers.n).decode(),
  242. "e": to_base64url_uint(numbers.e).decode(),
  243. }
  244. else:
  245. raise InvalidKeyError("Not a public or private key")
  246. return json.dumps(obj)
  247. @staticmethod
  248. def from_jwk(jwk):
  249. try:
  250. if isinstance(jwk, str):
  251. obj = json.loads(jwk)
  252. elif isinstance(jwk, dict):
  253. obj = jwk
  254. else:
  255. raise ValueError
  256. except ValueError:
  257. raise InvalidKeyError("Key is not valid JSON")
  258. if obj.get("kty") != "RSA":
  259. raise InvalidKeyError("Not an RSA key")
  260. if "d" in obj and "e" in obj and "n" in obj:
  261. # Private key
  262. if "oth" in obj:
  263. raise InvalidKeyError(
  264. "Unsupported RSA private key: > 2 primes not supported"
  265. )
  266. other_props = ["p", "q", "dp", "dq", "qi"]
  267. props_found = [prop in obj for prop in other_props]
  268. any_props_found = any(props_found)
  269. if any_props_found and not all(props_found):
  270. raise InvalidKeyError(
  271. "RSA key must include all parameters if any are present besides d"
  272. )
  273. public_numbers = RSAPublicNumbers(
  274. from_base64url_uint(obj["e"]),
  275. from_base64url_uint(obj["n"]),
  276. )
  277. if any_props_found:
  278. numbers = RSAPrivateNumbers(
  279. d=from_base64url_uint(obj["d"]),
  280. p=from_base64url_uint(obj["p"]),
  281. q=from_base64url_uint(obj["q"]),
  282. dmp1=from_base64url_uint(obj["dp"]),
  283. dmq1=from_base64url_uint(obj["dq"]),
  284. iqmp=from_base64url_uint(obj["qi"]),
  285. public_numbers=public_numbers,
  286. )
  287. else:
  288. d = from_base64url_uint(obj["d"])
  289. p, q = rsa_recover_prime_factors(
  290. public_numbers.n, d, public_numbers.e
  291. )
  292. numbers = RSAPrivateNumbers(
  293. d=d,
  294. p=p,
  295. q=q,
  296. dmp1=rsa_crt_dmp1(d, p),
  297. dmq1=rsa_crt_dmq1(d, q),
  298. iqmp=rsa_crt_iqmp(p, q),
  299. public_numbers=public_numbers,
  300. )
  301. return numbers.private_key()
  302. elif "n" in obj and "e" in obj:
  303. # Public key
  304. numbers = RSAPublicNumbers(
  305. from_base64url_uint(obj["e"]),
  306. from_base64url_uint(obj["n"]),
  307. )
  308. return numbers.public_key()
  309. else:
  310. raise InvalidKeyError("Not a public or private key")
  311. def sign(self, msg, key):
  312. return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
  313. def verify(self, msg, key, sig):
  314. try:
  315. key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
  316. return True
  317. except InvalidSignature:
  318. return False
  319. class ECAlgorithm(Algorithm):
  320. """
  321. Performs signing and verification operations using
  322. ECDSA and the specified hash function
  323. """
  324. SHA256 = hashes.SHA256
  325. SHA384 = hashes.SHA384
  326. SHA512 = hashes.SHA512
  327. def __init__(self, hash_alg):
  328. self.hash_alg = hash_alg
  329. def prepare_key(self, key):
  330. if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
  331. return key
  332. if not isinstance(key, (bytes, str)):
  333. raise TypeError("Expecting a PEM-formatted key.")
  334. key = force_bytes(key)
  335. # Attempt to load key. We don't know if it's
  336. # a Signing Key or a Verifying Key, so we try
  337. # the Verifying Key first.
  338. try:
  339. if key.startswith(b"ecdsa-sha2-"):
  340. key = load_ssh_public_key(key)
  341. else:
  342. key = load_pem_public_key(key)
  343. except ValueError:
  344. key = load_pem_private_key(key, password=None)
  345. # Explicit check the key to prevent confusing errors from cryptography
  346. if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
  347. raise InvalidKeyError(
  348. "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
  349. )
  350. return key
  351. def sign(self, msg, key):
  352. der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
  353. return der_to_raw_signature(der_sig, key.curve)
  354. def verify(self, msg, key, sig):
  355. try:
  356. der_sig = raw_to_der_signature(sig, key.curve)
  357. except ValueError:
  358. return False
  359. try:
  360. if isinstance(key, EllipticCurvePrivateKey):
  361. key = key.public_key()
  362. key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
  363. return True
  364. except InvalidSignature:
  365. return False
  366. @staticmethod
  367. def from_jwk(jwk):
  368. try:
  369. if isinstance(jwk, str):
  370. obj = json.loads(jwk)
  371. elif isinstance(jwk, dict):
  372. obj = jwk
  373. else:
  374. raise ValueError
  375. except ValueError:
  376. raise InvalidKeyError("Key is not valid JSON")
  377. if obj.get("kty") != "EC":
  378. raise InvalidKeyError("Not an Elliptic curve key")
  379. if "x" not in obj or "y" not in obj:
  380. raise InvalidKeyError("Not an Elliptic curve key")
  381. x = base64url_decode(obj.get("x"))
  382. y = base64url_decode(obj.get("y"))
  383. curve = obj.get("crv")
  384. if curve == "P-256":
  385. if len(x) == len(y) == 32:
  386. curve_obj = ec.SECP256R1()
  387. else:
  388. raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
  389. elif curve == "P-384":
  390. if len(x) == len(y) == 48:
  391. curve_obj = ec.SECP384R1()
  392. else:
  393. raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
  394. elif curve == "P-521":
  395. if len(x) == len(y) == 66:
  396. curve_obj = ec.SECP521R1()
  397. else:
  398. raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
  399. elif curve == "secp256k1":
  400. if len(x) == len(y) == 32:
  401. curve_obj = ec.SECP256K1()
  402. else:
  403. raise InvalidKeyError(
  404. "Coords should be 32 bytes for curve secp256k1"
  405. )
  406. else:
  407. raise InvalidKeyError(f"Invalid curve: {curve}")
  408. public_numbers = ec.EllipticCurvePublicNumbers(
  409. x=int.from_bytes(x, byteorder="big"),
  410. y=int.from_bytes(y, byteorder="big"),
  411. curve=curve_obj,
  412. )
  413. if "d" not in obj:
  414. return public_numbers.public_key()
  415. d = base64url_decode(obj.get("d"))
  416. if len(d) != len(x):
  417. raise InvalidKeyError(
  418. "D should be {} bytes for curve {}", len(x), curve
  419. )
  420. return ec.EllipticCurvePrivateNumbers(
  421. int.from_bytes(d, byteorder="big"), public_numbers
  422. ).private_key()
  423. class RSAPSSAlgorithm(RSAAlgorithm):
  424. """
  425. Performs a signature using RSASSA-PSS with MGF1
  426. """
  427. def sign(self, msg, key):
  428. return key.sign(
  429. msg,
  430. padding.PSS(
  431. mgf=padding.MGF1(self.hash_alg()),
  432. salt_length=self.hash_alg.digest_size,
  433. ),
  434. self.hash_alg(),
  435. )
  436. def verify(self, msg, key, sig):
  437. try:
  438. key.verify(
  439. sig,
  440. msg,
  441. padding.PSS(
  442. mgf=padding.MGF1(self.hash_alg()),
  443. salt_length=self.hash_alg.digest_size,
  444. ),
  445. self.hash_alg(),
  446. )
  447. return True
  448. except InvalidSignature:
  449. return False
  450. class OKPAlgorithm(Algorithm):
  451. """
  452. Performs signing and verification operations using EdDSA
  453. This class requires ``cryptography>=2.6`` to be installed.
  454. """
  455. def __init__(self, **kwargs):
  456. pass
  457. def prepare_key(self, key):
  458. if isinstance(key, (bytes, str)):
  459. if isinstance(key, str):
  460. key = key.encode("utf-8")
  461. str_key = key.decode("utf-8")
  462. if "-----BEGIN PUBLIC" in str_key:
  463. key = load_pem_public_key(key)
  464. elif "-----BEGIN PRIVATE" in str_key:
  465. key = load_pem_private_key(key, password=None)
  466. elif str_key[0:4] == "ssh-":
  467. key = load_ssh_public_key(key)
  468. # Explicit check the key to prevent confusing errors from cryptography
  469. if not isinstance(
  470. key,
  471. (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
  472. ):
  473. raise InvalidKeyError(
  474. "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
  475. )
  476. return key
  477. def sign(self, msg, key):
  478. """
  479. Sign a message ``msg`` using the EdDSA private key ``key``
  480. :param str|bytes msg: Message to sign
  481. :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
  482. or :class:`.Ed448PrivateKey` iinstance
  483. :return bytes signature: The signature, as bytes
  484. """
  485. msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
  486. return key.sign(msg)
  487. def verify(self, msg, key, sig):
  488. """
  489. Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
  490. :param str|bytes sig: EdDSA signature to check ``msg`` against
  491. :param str|bytes msg: Message to sign
  492. :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
  493. A private or public EdDSA key instance
  494. :return bool verified: True if signature is valid, False if not.
  495. """
  496. try:
  497. msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
  498. sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig
  499. if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
  500. key = key.public_key()
  501. key.verify(sig, msg)
  502. return True # If no exception was raised, the signature is valid.
  503. except cryptography.exceptions.InvalidSignature:
  504. return False
  505. @staticmethod
  506. def to_jwk(key):
  507. if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
  508. x = key.public_bytes(
  509. encoding=Encoding.Raw,
  510. format=PublicFormat.Raw,
  511. )
  512. crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
  513. return json.dumps(
  514. {
  515. "x": base64url_encode(force_bytes(x)).decode(),
  516. "kty": "OKP",
  517. "crv": crv,
  518. }
  519. )
  520. if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
  521. d = key.private_bytes(
  522. encoding=Encoding.Raw,
  523. format=PrivateFormat.Raw,
  524. encryption_algorithm=NoEncryption(),
  525. )
  526. x = key.public_key().public_bytes(
  527. encoding=Encoding.Raw,
  528. format=PublicFormat.Raw,
  529. )
  530. crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
  531. return json.dumps(
  532. {
  533. "x": base64url_encode(force_bytes(x)).decode(),
  534. "d": base64url_encode(force_bytes(d)).decode(),
  535. "kty": "OKP",
  536. "crv": crv,
  537. }
  538. )
  539. raise InvalidKeyError("Not a public or private key")
  540. @staticmethod
  541. def from_jwk(jwk):
  542. try:
  543. if isinstance(jwk, str):
  544. obj = json.loads(jwk)
  545. elif isinstance(jwk, dict):
  546. obj = jwk
  547. else:
  548. raise ValueError
  549. except ValueError:
  550. raise InvalidKeyError("Key is not valid JSON")
  551. if obj.get("kty") != "OKP":
  552. raise InvalidKeyError("Not an Octet Key Pair")
  553. curve = obj.get("crv")
  554. if curve != "Ed25519" and curve != "Ed448":
  555. raise InvalidKeyError(f"Invalid curve: {curve}")
  556. if "x" not in obj:
  557. raise InvalidKeyError('OKP should have "x" parameter')
  558. x = base64url_decode(obj.get("x"))
  559. try:
  560. if "d" not in obj:
  561. if curve == "Ed25519":
  562. return Ed25519PublicKey.from_public_bytes(x)
  563. return Ed448PublicKey.from_public_bytes(x)
  564. d = base64url_decode(obj.get("d"))
  565. if curve == "Ed25519":
  566. return Ed25519PrivateKey.from_private_bytes(d)
  567. return Ed448PrivateKey.from_private_bytes(d)
  568. except ValueError as err:
  569. raise InvalidKeyError("Invalid key parameter") from err