rsa.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import abc
  5. import typing
  6. from math import gcd
  7. from cryptography.hazmat.primitives import _serialization, hashes
  8. from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
  9. from cryptography.hazmat.primitives.asymmetric import (
  10. utils as asym_utils,
  11. )
  12. class RSAPrivateKey(metaclass=abc.ABCMeta):
  13. @abc.abstractmethod
  14. def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
  15. """
  16. Decrypts the provided ciphertext.
  17. """
  18. @abc.abstractproperty
  19. def key_size(self) -> int:
  20. """
  21. The bit length of the public modulus.
  22. """
  23. @abc.abstractmethod
  24. def public_key(self) -> "RSAPublicKey":
  25. """
  26. The RSAPublicKey associated with this private key.
  27. """
  28. @abc.abstractmethod
  29. def sign(
  30. self,
  31. data: bytes,
  32. padding: AsymmetricPadding,
  33. algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
  34. ) -> bytes:
  35. """
  36. Signs the data.
  37. """
  38. @abc.abstractmethod
  39. def private_numbers(self) -> "RSAPrivateNumbers":
  40. """
  41. Returns an RSAPrivateNumbers.
  42. """
  43. @abc.abstractmethod
  44. def private_bytes(
  45. self,
  46. encoding: _serialization.Encoding,
  47. format: _serialization.PrivateFormat,
  48. encryption_algorithm: _serialization.KeySerializationEncryption,
  49. ) -> bytes:
  50. """
  51. Returns the key serialized as bytes.
  52. """
  53. RSAPrivateKeyWithSerialization = RSAPrivateKey
  54. class RSAPublicKey(metaclass=abc.ABCMeta):
  55. @abc.abstractmethod
  56. def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
  57. """
  58. Encrypts the given plaintext.
  59. """
  60. @abc.abstractproperty
  61. def key_size(self) -> int:
  62. """
  63. The bit length of the public modulus.
  64. """
  65. @abc.abstractmethod
  66. def public_numbers(self) -> "RSAPublicNumbers":
  67. """
  68. Returns an RSAPublicNumbers
  69. """
  70. @abc.abstractmethod
  71. def public_bytes(
  72. self,
  73. encoding: _serialization.Encoding,
  74. format: _serialization.PublicFormat,
  75. ) -> bytes:
  76. """
  77. Returns the key serialized as bytes.
  78. """
  79. @abc.abstractmethod
  80. def verify(
  81. self,
  82. signature: bytes,
  83. data: bytes,
  84. padding: AsymmetricPadding,
  85. algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
  86. ) -> None:
  87. """
  88. Verifies the signature of the data.
  89. """
  90. @abc.abstractmethod
  91. def recover_data_from_signature(
  92. self,
  93. signature: bytes,
  94. padding: AsymmetricPadding,
  95. algorithm: typing.Optional[hashes.HashAlgorithm],
  96. ) -> bytes:
  97. """
  98. Recovers the original data from the signature.
  99. """
  100. RSAPublicKeyWithSerialization = RSAPublicKey
  101. def generate_private_key(
  102. public_exponent: int,
  103. key_size: int,
  104. backend: typing.Any = None,
  105. ) -> RSAPrivateKey:
  106. from cryptography.hazmat.backends.openssl.backend import backend as ossl
  107. _verify_rsa_parameters(public_exponent, key_size)
  108. return ossl.generate_rsa_private_key(public_exponent, key_size)
  109. def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
  110. if public_exponent not in (3, 65537):
  111. raise ValueError(
  112. "public_exponent must be either 3 (for legacy compatibility) or "
  113. "65537. Almost everyone should choose 65537 here!"
  114. )
  115. if key_size < 512:
  116. raise ValueError("key_size must be at least 512-bits.")
  117. def _check_private_key_components(
  118. p: int,
  119. q: int,
  120. private_exponent: int,
  121. dmp1: int,
  122. dmq1: int,
  123. iqmp: int,
  124. public_exponent: int,
  125. modulus: int,
  126. ) -> None:
  127. if modulus < 3:
  128. raise ValueError("modulus must be >= 3.")
  129. if p >= modulus:
  130. raise ValueError("p must be < modulus.")
  131. if q >= modulus:
  132. raise ValueError("q must be < modulus.")
  133. if dmp1 >= modulus:
  134. raise ValueError("dmp1 must be < modulus.")
  135. if dmq1 >= modulus:
  136. raise ValueError("dmq1 must be < modulus.")
  137. if iqmp >= modulus:
  138. raise ValueError("iqmp must be < modulus.")
  139. if private_exponent >= modulus:
  140. raise ValueError("private_exponent must be < modulus.")
  141. if public_exponent < 3 or public_exponent >= modulus:
  142. raise ValueError("public_exponent must be >= 3 and < modulus.")
  143. if public_exponent & 1 == 0:
  144. raise ValueError("public_exponent must be odd.")
  145. if dmp1 & 1 == 0:
  146. raise ValueError("dmp1 must be odd.")
  147. if dmq1 & 1 == 0:
  148. raise ValueError("dmq1 must be odd.")
  149. if p * q != modulus:
  150. raise ValueError("p*q must equal modulus.")
  151. def _check_public_key_components(e: int, n: int) -> None:
  152. if n < 3:
  153. raise ValueError("n must be >= 3.")
  154. if e < 3 or e >= n:
  155. raise ValueError("e must be >= 3 and < n.")
  156. if e & 1 == 0:
  157. raise ValueError("e must be odd.")
  158. def _modinv(e: int, m: int) -> int:
  159. """
  160. Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
  161. """
  162. x1, x2 = 1, 0
  163. a, b = e, m
  164. while b > 0:
  165. q, r = divmod(a, b)
  166. xn = x1 - q * x2
  167. a, b, x1, x2 = b, r, x2, xn
  168. return x1 % m
  169. def rsa_crt_iqmp(p: int, q: int) -> int:
  170. """
  171. Compute the CRT (q ** -1) % p value from RSA primes p and q.
  172. """
  173. return _modinv(q, p)
  174. def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
  175. """
  176. Compute the CRT private_exponent % (p - 1) value from the RSA
  177. private_exponent (d) and p.
  178. """
  179. return private_exponent % (p - 1)
  180. def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
  181. """
  182. Compute the CRT private_exponent % (q - 1) value from the RSA
  183. private_exponent (d) and q.
  184. """
  185. return private_exponent % (q - 1)
  186. # Controls the number of iterations rsa_recover_prime_factors will perform
  187. # to obtain the prime factors. Each iteration increments by 2 so the actual
  188. # maximum attempts is half this number.
  189. _MAX_RECOVERY_ATTEMPTS = 1000
  190. def rsa_recover_prime_factors(
  191. n: int, e: int, d: int
  192. ) -> typing.Tuple[int, int]:
  193. """
  194. Compute factors p and q from the private exponent d. We assume that n has
  195. no more than two factors. This function is adapted from code in PyCrypto.
  196. """
  197. # See 8.2.2(i) in Handbook of Applied Cryptography.
  198. ktot = d * e - 1
  199. # The quantity d*e-1 is a multiple of phi(n), even,
  200. # and can be represented as t*2^s.
  201. t = ktot
  202. while t % 2 == 0:
  203. t = t // 2
  204. # Cycle through all multiplicative inverses in Zn.
  205. # The algorithm is non-deterministic, but there is a 50% chance
  206. # any candidate a leads to successful factoring.
  207. # See "Digitalized Signatures and Public Key Functions as Intractable
  208. # as Factorization", M. Rabin, 1979
  209. spotted = False
  210. a = 2
  211. while not spotted and a < _MAX_RECOVERY_ATTEMPTS:
  212. k = t
  213. # Cycle through all values a^{t*2^i}=a^k
  214. while k < ktot:
  215. cand = pow(a, k, n)
  216. # Check if a^k is a non-trivial root of unity (mod n)
  217. if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
  218. # We have found a number such that (cand-1)(cand+1)=0 (mod n).
  219. # Either of the terms divides n.
  220. p = gcd(cand + 1, n)
  221. spotted = True
  222. break
  223. k *= 2
  224. # This value was not any good... let's try another!
  225. a += 2
  226. if not spotted:
  227. raise ValueError("Unable to compute factors p and q from exponent d.")
  228. # Found !
  229. q, r = divmod(n, p)
  230. assert r == 0
  231. p, q = sorted((p, q), reverse=True)
  232. return (p, q)
  233. class RSAPrivateNumbers:
  234. def __init__(
  235. self,
  236. p: int,
  237. q: int,
  238. d: int,
  239. dmp1: int,
  240. dmq1: int,
  241. iqmp: int,
  242. public_numbers: "RSAPublicNumbers",
  243. ):
  244. if (
  245. not isinstance(p, int)
  246. or not isinstance(q, int)
  247. or not isinstance(d, int)
  248. or not isinstance(dmp1, int)
  249. or not isinstance(dmq1, int)
  250. or not isinstance(iqmp, int)
  251. ):
  252. raise TypeError(
  253. "RSAPrivateNumbers p, q, d, dmp1, dmq1, iqmp arguments must"
  254. " all be an integers."
  255. )
  256. if not isinstance(public_numbers, RSAPublicNumbers):
  257. raise TypeError(
  258. "RSAPrivateNumbers public_numbers must be an RSAPublicNumbers"
  259. " instance."
  260. )
  261. self._p = p
  262. self._q = q
  263. self._d = d
  264. self._dmp1 = dmp1
  265. self._dmq1 = dmq1
  266. self._iqmp = iqmp
  267. self._public_numbers = public_numbers
  268. @property
  269. def p(self) -> int:
  270. return self._p
  271. @property
  272. def q(self) -> int:
  273. return self._q
  274. @property
  275. def d(self) -> int:
  276. return self._d
  277. @property
  278. def dmp1(self) -> int:
  279. return self._dmp1
  280. @property
  281. def dmq1(self) -> int:
  282. return self._dmq1
  283. @property
  284. def iqmp(self) -> int:
  285. return self._iqmp
  286. @property
  287. def public_numbers(self) -> "RSAPublicNumbers":
  288. return self._public_numbers
  289. def private_key(self, backend: typing.Any = None) -> RSAPrivateKey:
  290. from cryptography.hazmat.backends.openssl.backend import (
  291. backend as ossl,
  292. )
  293. return ossl.load_rsa_private_numbers(self)
  294. def __eq__(self, other: object) -> bool:
  295. if not isinstance(other, RSAPrivateNumbers):
  296. return NotImplemented
  297. return (
  298. self.p == other.p
  299. and self.q == other.q
  300. and self.d == other.d
  301. and self.dmp1 == other.dmp1
  302. and self.dmq1 == other.dmq1
  303. and self.iqmp == other.iqmp
  304. and self.public_numbers == other.public_numbers
  305. )
  306. def __hash__(self) -> int:
  307. return hash(
  308. (
  309. self.p,
  310. self.q,
  311. self.d,
  312. self.dmp1,
  313. self.dmq1,
  314. self.iqmp,
  315. self.public_numbers,
  316. )
  317. )
  318. class RSAPublicNumbers:
  319. def __init__(self, e: int, n: int):
  320. if not isinstance(e, int) or not isinstance(n, int):
  321. raise TypeError("RSAPublicNumbers arguments must be integers.")
  322. self._e = e
  323. self._n = n
  324. @property
  325. def e(self) -> int:
  326. return self._e
  327. @property
  328. def n(self) -> int:
  329. return self._n
  330. def public_key(self, backend: typing.Any = None) -> RSAPublicKey:
  331. from cryptography.hazmat.backends.openssl.backend import (
  332. backend as ossl,
  333. )
  334. return ossl.load_rsa_public_numbers(self)
  335. def __repr__(self) -> str:
  336. return "<RSAPublicNumbers(e={0.e}, n={0.n})>".format(self)
  337. def __eq__(self, other: object) -> bool:
  338. if not isinstance(other, RSAPublicNumbers):
  339. return NotImplemented
  340. return self.e == other.e and self.n == other.n
  341. def __hash__(self) -> int:
  342. return hash((self.e, self.n))