ssh.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import binascii
  5. import os
  6. import re
  7. import typing
  8. from base64 import encodebytes as _base64_encode
  9. from cryptography import utils
  10. from cryptography.exceptions import UnsupportedAlgorithm
  11. from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
  12. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  13. from cryptography.hazmat.primitives.serialization import (
  14. Encoding,
  15. NoEncryption,
  16. PrivateFormat,
  17. PublicFormat,
  18. )
  19. try:
  20. from bcrypt import kdf as _bcrypt_kdf
  21. _bcrypt_supported = True
  22. except ImportError:
  23. _bcrypt_supported = False
  24. def _bcrypt_kdf(
  25. password: bytes,
  26. salt: bytes,
  27. desired_key_bytes: int,
  28. rounds: int,
  29. ignore_few_rounds: bool = False,
  30. ) -> bytes:
  31. raise UnsupportedAlgorithm("Need bcrypt module")
  32. _SSH_ED25519 = b"ssh-ed25519"
  33. _SSH_RSA = b"ssh-rsa"
  34. _SSH_DSA = b"ssh-dss"
  35. _ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
  36. _ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
  37. _ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
  38. _CERT_SUFFIX = b"-cert-v01@openssh.com"
  39. _SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
  40. _SK_MAGIC = b"openssh-key-v1\0"
  41. _SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
  42. _SK_END = b"-----END OPENSSH PRIVATE KEY-----"
  43. _BCRYPT = b"bcrypt"
  44. _NONE = b"none"
  45. _DEFAULT_CIPHER = b"aes256-ctr"
  46. _DEFAULT_ROUNDS = 16
  47. _MAX_PASSWORD = 72
  48. # re is only way to work on bytes-like data
  49. _PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
  50. # padding for max blocksize
  51. _PADDING = memoryview(bytearray(range(1, 1 + 16)))
  52. # ciphers that are actually used in key wrapping
  53. _SSH_CIPHERS: typing.Dict[
  54. bytes,
  55. typing.Tuple[
  56. typing.Type[algorithms.AES],
  57. int,
  58. typing.Union[typing.Type[modes.CTR], typing.Type[modes.CBC]],
  59. int,
  60. ],
  61. ] = {
  62. b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
  63. b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
  64. }
  65. # map local curve name to key type
  66. _ECDSA_KEY_TYPE = {
  67. "secp256r1": _ECDSA_NISTP256,
  68. "secp384r1": _ECDSA_NISTP384,
  69. "secp521r1": _ECDSA_NISTP521,
  70. }
  71. def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
  72. """Return SSH key_type and curve_name for private key."""
  73. curve = public_key.curve
  74. if curve.name not in _ECDSA_KEY_TYPE:
  75. raise ValueError(
  76. f"Unsupported curve for ssh private key: {curve.name!r}"
  77. )
  78. return _ECDSA_KEY_TYPE[curve.name]
  79. def _ssh_pem_encode(
  80. data: bytes,
  81. prefix: bytes = _SK_START + b"\n",
  82. suffix: bytes = _SK_END + b"\n",
  83. ) -> bytes:
  84. return b"".join([prefix, _base64_encode(data), suffix])
  85. def _check_block_size(data: bytes, block_len: int) -> None:
  86. """Require data to be full blocks"""
  87. if not data or len(data) % block_len != 0:
  88. raise ValueError("Corrupt data: missing padding")
  89. def _check_empty(data: bytes) -> None:
  90. """All data should have been parsed."""
  91. if data:
  92. raise ValueError("Corrupt data: unparsed data")
  93. def _init_cipher(
  94. ciphername: bytes,
  95. password: typing.Optional[bytes],
  96. salt: bytes,
  97. rounds: int,
  98. ) -> Cipher[typing.Union[modes.CBC, modes.CTR]]:
  99. """Generate key + iv and return cipher."""
  100. if not password:
  101. raise ValueError("Key is password-protected.")
  102. algo, key_len, mode, iv_len = _SSH_CIPHERS[ciphername]
  103. seed = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
  104. return Cipher(algo(seed[:key_len]), mode(seed[key_len:]))
  105. def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
  106. """Uint32"""
  107. if len(data) < 4:
  108. raise ValueError("Invalid data")
  109. return int.from_bytes(data[:4], byteorder="big"), data[4:]
  110. def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
  111. """Uint64"""
  112. if len(data) < 8:
  113. raise ValueError("Invalid data")
  114. return int.from_bytes(data[:8], byteorder="big"), data[8:]
  115. def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
  116. """Bytes with u32 length prefix"""
  117. n, data = _get_u32(data)
  118. if n > len(data):
  119. raise ValueError("Invalid data")
  120. return data[:n], data[n:]
  121. def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
  122. """Big integer."""
  123. val, data = _get_sshstr(data)
  124. if val and val[0] > 0x7F:
  125. raise ValueError("Invalid data")
  126. return int.from_bytes(val, "big"), data
  127. def _to_mpint(val: int) -> bytes:
  128. """Storage format for signed bigint."""
  129. if val < 0:
  130. raise ValueError("negative mpint not allowed")
  131. if not val:
  132. return b""
  133. nbytes = (val.bit_length() + 8) // 8
  134. return utils.int_to_bytes(val, nbytes)
  135. class _FragList:
  136. """Build recursive structure without data copy."""
  137. flist: typing.List[bytes]
  138. def __init__(self, init: typing.List[bytes] = None) -> None:
  139. self.flist = []
  140. if init:
  141. self.flist.extend(init)
  142. def put_raw(self, val: bytes) -> None:
  143. """Add plain bytes"""
  144. self.flist.append(val)
  145. def put_u32(self, val: int) -> None:
  146. """Big-endian uint32"""
  147. self.flist.append(val.to_bytes(length=4, byteorder="big"))
  148. def put_sshstr(self, val: typing.Union[bytes, "_FragList"]) -> None:
  149. """Bytes prefixed with u32 length"""
  150. if isinstance(val, (bytes, memoryview, bytearray)):
  151. self.put_u32(len(val))
  152. self.flist.append(val)
  153. else:
  154. self.put_u32(val.size())
  155. self.flist.extend(val.flist)
  156. def put_mpint(self, val: int) -> None:
  157. """Big-endian bigint prefixed with u32 length"""
  158. self.put_sshstr(_to_mpint(val))
  159. def size(self) -> int:
  160. """Current number of bytes"""
  161. return sum(map(len, self.flist))
  162. def render(self, dstbuf: memoryview, pos: int = 0) -> int:
  163. """Write into bytearray"""
  164. for frag in self.flist:
  165. flen = len(frag)
  166. start, pos = pos, pos + flen
  167. dstbuf[start:pos] = frag
  168. return pos
  169. def tobytes(self) -> bytes:
  170. """Return as bytes"""
  171. buf = memoryview(bytearray(self.size()))
  172. self.render(buf)
  173. return buf.tobytes()
  174. class _SSHFormatRSA:
  175. """Format for RSA keys.
  176. Public:
  177. mpint e, n
  178. Private:
  179. mpint n, e, d, iqmp, p, q
  180. """
  181. def get_public(self, data: memoryview):
  182. """RSA public fields"""
  183. e, data = _get_mpint(data)
  184. n, data = _get_mpint(data)
  185. return (e, n), data
  186. def load_public(
  187. self, data: memoryview
  188. ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
  189. """Make RSA public key from data."""
  190. (e, n), data = self.get_public(data)
  191. public_numbers = rsa.RSAPublicNumbers(e, n)
  192. public_key = public_numbers.public_key()
  193. return public_key, data
  194. def load_private(
  195. self, data: memoryview, pubfields
  196. ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
  197. """Make RSA private key from data."""
  198. n, data = _get_mpint(data)
  199. e, data = _get_mpint(data)
  200. d, data = _get_mpint(data)
  201. iqmp, data = _get_mpint(data)
  202. p, data = _get_mpint(data)
  203. q, data = _get_mpint(data)
  204. if (e, n) != pubfields:
  205. raise ValueError("Corrupt data: rsa field mismatch")
  206. dmp1 = rsa.rsa_crt_dmp1(d, p)
  207. dmq1 = rsa.rsa_crt_dmq1(d, q)
  208. public_numbers = rsa.RSAPublicNumbers(e, n)
  209. private_numbers = rsa.RSAPrivateNumbers(
  210. p, q, d, dmp1, dmq1, iqmp, public_numbers
  211. )
  212. private_key = private_numbers.private_key()
  213. return private_key, data
  214. def encode_public(
  215. self, public_key: rsa.RSAPublicKey, f_pub: _FragList
  216. ) -> None:
  217. """Write RSA public key"""
  218. pubn = public_key.public_numbers()
  219. f_pub.put_mpint(pubn.e)
  220. f_pub.put_mpint(pubn.n)
  221. def encode_private(
  222. self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
  223. ) -> None:
  224. """Write RSA private key"""
  225. private_numbers = private_key.private_numbers()
  226. public_numbers = private_numbers.public_numbers
  227. f_priv.put_mpint(public_numbers.n)
  228. f_priv.put_mpint(public_numbers.e)
  229. f_priv.put_mpint(private_numbers.d)
  230. f_priv.put_mpint(private_numbers.iqmp)
  231. f_priv.put_mpint(private_numbers.p)
  232. f_priv.put_mpint(private_numbers.q)
  233. class _SSHFormatDSA:
  234. """Format for DSA keys.
  235. Public:
  236. mpint p, q, g, y
  237. Private:
  238. mpint p, q, g, y, x
  239. """
  240. def get_public(
  241. self, data: memoryview
  242. ) -> typing.Tuple[typing.Tuple, memoryview]:
  243. """DSA public fields"""
  244. p, data = _get_mpint(data)
  245. q, data = _get_mpint(data)
  246. g, data = _get_mpint(data)
  247. y, data = _get_mpint(data)
  248. return (p, q, g, y), data
  249. def load_public(
  250. self, data: memoryview
  251. ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
  252. """Make DSA public key from data."""
  253. (p, q, g, y), data = self.get_public(data)
  254. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  255. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  256. self._validate(public_numbers)
  257. public_key = public_numbers.public_key()
  258. return public_key, data
  259. def load_private(
  260. self, data: memoryview, pubfields
  261. ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
  262. """Make DSA private key from data."""
  263. (p, q, g, y), data = self.get_public(data)
  264. x, data = _get_mpint(data)
  265. if (p, q, g, y) != pubfields:
  266. raise ValueError("Corrupt data: dsa field mismatch")
  267. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  268. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  269. self._validate(public_numbers)
  270. private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
  271. private_key = private_numbers.private_key()
  272. return private_key, data
  273. def encode_public(
  274. self, public_key: dsa.DSAPublicKey, f_pub: _FragList
  275. ) -> None:
  276. """Write DSA public key"""
  277. public_numbers = public_key.public_numbers()
  278. parameter_numbers = public_numbers.parameter_numbers
  279. self._validate(public_numbers)
  280. f_pub.put_mpint(parameter_numbers.p)
  281. f_pub.put_mpint(parameter_numbers.q)
  282. f_pub.put_mpint(parameter_numbers.g)
  283. f_pub.put_mpint(public_numbers.y)
  284. def encode_private(
  285. self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
  286. ) -> None:
  287. """Write DSA private key"""
  288. self.encode_public(private_key.public_key(), f_priv)
  289. f_priv.put_mpint(private_key.private_numbers().x)
  290. def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
  291. parameter_numbers = public_numbers.parameter_numbers
  292. if parameter_numbers.p.bit_length() != 1024:
  293. raise ValueError("SSH supports only 1024 bit DSA keys")
  294. class _SSHFormatECDSA:
  295. """Format for ECDSA keys.
  296. Public:
  297. str curve
  298. bytes point
  299. Private:
  300. str curve
  301. bytes point
  302. mpint secret
  303. """
  304. def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
  305. self.ssh_curve_name = ssh_curve_name
  306. self.curve = curve
  307. def get_public(
  308. self, data: memoryview
  309. ) -> typing.Tuple[typing.Tuple, memoryview]:
  310. """ECDSA public fields"""
  311. curve, data = _get_sshstr(data)
  312. point, data = _get_sshstr(data)
  313. if curve != self.ssh_curve_name:
  314. raise ValueError("Curve name mismatch")
  315. if point[0] != 4:
  316. raise NotImplementedError("Need uncompressed point")
  317. return (curve, point), data
  318. def load_public(
  319. self, data: memoryview
  320. ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
  321. """Make ECDSA public key from data."""
  322. (curve_name, point), data = self.get_public(data)
  323. public_key = ec.EllipticCurvePublicKey.from_encoded_point(
  324. self.curve, point.tobytes()
  325. )
  326. return public_key, data
  327. def load_private(
  328. self, data: memoryview, pubfields
  329. ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
  330. """Make ECDSA private key from data."""
  331. (curve_name, point), data = self.get_public(data)
  332. secret, data = _get_mpint(data)
  333. if (curve_name, point) != pubfields:
  334. raise ValueError("Corrupt data: ecdsa field mismatch")
  335. private_key = ec.derive_private_key(secret, self.curve)
  336. return private_key, data
  337. def encode_public(
  338. self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
  339. ) -> None:
  340. """Write ECDSA public key"""
  341. point = public_key.public_bytes(
  342. Encoding.X962, PublicFormat.UncompressedPoint
  343. )
  344. f_pub.put_sshstr(self.ssh_curve_name)
  345. f_pub.put_sshstr(point)
  346. def encode_private(
  347. self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
  348. ) -> None:
  349. """Write ECDSA private key"""
  350. public_key = private_key.public_key()
  351. private_numbers = private_key.private_numbers()
  352. self.encode_public(public_key, f_priv)
  353. f_priv.put_mpint(private_numbers.private_value)
  354. class _SSHFormatEd25519:
  355. """Format for Ed25519 keys.
  356. Public:
  357. bytes point
  358. Private:
  359. bytes point
  360. bytes secret_and_point
  361. """
  362. def get_public(
  363. self, data: memoryview
  364. ) -> typing.Tuple[typing.Tuple, memoryview]:
  365. """Ed25519 public fields"""
  366. point, data = _get_sshstr(data)
  367. return (point,), data
  368. def load_public(
  369. self, data: memoryview
  370. ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
  371. """Make Ed25519 public key from data."""
  372. (point,), data = self.get_public(data)
  373. public_key = ed25519.Ed25519PublicKey.from_public_bytes(
  374. point.tobytes()
  375. )
  376. return public_key, data
  377. def load_private(
  378. self, data: memoryview, pubfields
  379. ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
  380. """Make Ed25519 private key from data."""
  381. (point,), data = self.get_public(data)
  382. keypair, data = _get_sshstr(data)
  383. secret = keypair[:32]
  384. point2 = keypair[32:]
  385. if point != point2 or (point,) != pubfields:
  386. raise ValueError("Corrupt data: ed25519 field mismatch")
  387. private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
  388. return private_key, data
  389. def encode_public(
  390. self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
  391. ) -> None:
  392. """Write Ed25519 public key"""
  393. raw_public_key = public_key.public_bytes(
  394. Encoding.Raw, PublicFormat.Raw
  395. )
  396. f_pub.put_sshstr(raw_public_key)
  397. def encode_private(
  398. self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
  399. ) -> None:
  400. """Write Ed25519 private key"""
  401. public_key = private_key.public_key()
  402. raw_private_key = private_key.private_bytes(
  403. Encoding.Raw, PrivateFormat.Raw, NoEncryption()
  404. )
  405. raw_public_key = public_key.public_bytes(
  406. Encoding.Raw, PublicFormat.Raw
  407. )
  408. f_keypair = _FragList([raw_private_key, raw_public_key])
  409. self.encode_public(public_key, f_priv)
  410. f_priv.put_sshstr(f_keypair)
  411. _KEY_FORMATS = {
  412. _SSH_RSA: _SSHFormatRSA(),
  413. _SSH_DSA: _SSHFormatDSA(),
  414. _SSH_ED25519: _SSHFormatEd25519(),
  415. _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
  416. _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
  417. _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
  418. }
  419. def _lookup_kformat(key_type: bytes):
  420. """Return valid format or throw error"""
  421. if not isinstance(key_type, bytes):
  422. key_type = memoryview(key_type).tobytes()
  423. if key_type in _KEY_FORMATS:
  424. return _KEY_FORMATS[key_type]
  425. raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")
  426. _SSH_PRIVATE_KEY_TYPES = typing.Union[
  427. ec.EllipticCurvePrivateKey,
  428. rsa.RSAPrivateKey,
  429. dsa.DSAPrivateKey,
  430. ed25519.Ed25519PrivateKey,
  431. ]
  432. def load_ssh_private_key(
  433. data: bytes,
  434. password: typing.Optional[bytes],
  435. backend: typing.Any = None,
  436. ) -> _SSH_PRIVATE_KEY_TYPES:
  437. """Load private key from OpenSSH custom encoding."""
  438. utils._check_byteslike("data", data)
  439. if password is not None:
  440. utils._check_bytes("password", password)
  441. m = _PEM_RC.search(data)
  442. if not m:
  443. raise ValueError("Not OpenSSH private key format")
  444. p1 = m.start(1)
  445. p2 = m.end(1)
  446. data = binascii.a2b_base64(memoryview(data)[p1:p2])
  447. if not data.startswith(_SK_MAGIC):
  448. raise ValueError("Not OpenSSH private key format")
  449. data = memoryview(data)[len(_SK_MAGIC) :]
  450. # parse header
  451. ciphername, data = _get_sshstr(data)
  452. kdfname, data = _get_sshstr(data)
  453. kdfoptions, data = _get_sshstr(data)
  454. nkeys, data = _get_u32(data)
  455. if nkeys != 1:
  456. raise ValueError("Only one key supported")
  457. # load public key data
  458. pubdata, data = _get_sshstr(data)
  459. pub_key_type, pubdata = _get_sshstr(pubdata)
  460. kformat = _lookup_kformat(pub_key_type)
  461. pubfields, pubdata = kformat.get_public(pubdata)
  462. _check_empty(pubdata)
  463. # load secret data
  464. edata, data = _get_sshstr(data)
  465. _check_empty(data)
  466. if (ciphername, kdfname) != (_NONE, _NONE):
  467. ciphername_bytes = ciphername.tobytes()
  468. if ciphername_bytes not in _SSH_CIPHERS:
  469. raise UnsupportedAlgorithm(
  470. f"Unsupported cipher: {ciphername_bytes!r}"
  471. )
  472. if kdfname != _BCRYPT:
  473. raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}")
  474. blklen = _SSH_CIPHERS[ciphername_bytes][3]
  475. _check_block_size(edata, blklen)
  476. salt, kbuf = _get_sshstr(kdfoptions)
  477. rounds, kbuf = _get_u32(kbuf)
  478. _check_empty(kbuf)
  479. ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
  480. edata = memoryview(ciph.decryptor().update(edata))
  481. else:
  482. blklen = 8
  483. _check_block_size(edata, blklen)
  484. ck1, edata = _get_u32(edata)
  485. ck2, edata = _get_u32(edata)
  486. if ck1 != ck2:
  487. raise ValueError("Corrupt data: broken checksum")
  488. # load per-key struct
  489. key_type, edata = _get_sshstr(edata)
  490. if key_type != pub_key_type:
  491. raise ValueError("Corrupt data: key type mismatch")
  492. private_key, edata = kformat.load_private(edata, pubfields)
  493. comment, edata = _get_sshstr(edata)
  494. # yes, SSH does padding check *after* all other parsing is done.
  495. # need to follow as it writes zero-byte padding too.
  496. if edata != _PADDING[: len(edata)]:
  497. raise ValueError("Corrupt data: invalid padding")
  498. return private_key
  499. def serialize_ssh_private_key(
  500. private_key: _SSH_PRIVATE_KEY_TYPES,
  501. password: typing.Optional[bytes] = None,
  502. ) -> bytes:
  503. """Serialize private key with OpenSSH custom encoding."""
  504. if password is not None:
  505. utils._check_bytes("password", password)
  506. if password and len(password) > _MAX_PASSWORD:
  507. raise ValueError(
  508. "Passwords longer than 72 bytes are not supported by "
  509. "OpenSSH private key format"
  510. )
  511. if isinstance(private_key, ec.EllipticCurvePrivateKey):
  512. key_type = _ecdsa_key_type(private_key.public_key())
  513. elif isinstance(private_key, rsa.RSAPrivateKey):
  514. key_type = _SSH_RSA
  515. elif isinstance(private_key, dsa.DSAPrivateKey):
  516. key_type = _SSH_DSA
  517. elif isinstance(private_key, ed25519.Ed25519PrivateKey):
  518. key_type = _SSH_ED25519
  519. else:
  520. raise ValueError("Unsupported key type")
  521. kformat = _lookup_kformat(key_type)
  522. # setup parameters
  523. f_kdfoptions = _FragList()
  524. if password:
  525. ciphername = _DEFAULT_CIPHER
  526. blklen = _SSH_CIPHERS[ciphername][3]
  527. kdfname = _BCRYPT
  528. rounds = _DEFAULT_ROUNDS
  529. salt = os.urandom(16)
  530. f_kdfoptions.put_sshstr(salt)
  531. f_kdfoptions.put_u32(rounds)
  532. ciph = _init_cipher(ciphername, password, salt, rounds)
  533. else:
  534. ciphername = kdfname = _NONE
  535. blklen = 8
  536. ciph = None
  537. nkeys = 1
  538. checkval = os.urandom(4)
  539. comment = b""
  540. # encode public and private parts together
  541. f_public_key = _FragList()
  542. f_public_key.put_sshstr(key_type)
  543. kformat.encode_public(private_key.public_key(), f_public_key)
  544. f_secrets = _FragList([checkval, checkval])
  545. f_secrets.put_sshstr(key_type)
  546. kformat.encode_private(private_key, f_secrets)
  547. f_secrets.put_sshstr(comment)
  548. f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])
  549. # top-level structure
  550. f_main = _FragList()
  551. f_main.put_raw(_SK_MAGIC)
  552. f_main.put_sshstr(ciphername)
  553. f_main.put_sshstr(kdfname)
  554. f_main.put_sshstr(f_kdfoptions)
  555. f_main.put_u32(nkeys)
  556. f_main.put_sshstr(f_public_key)
  557. f_main.put_sshstr(f_secrets)
  558. # copy result info bytearray
  559. slen = f_secrets.size()
  560. mlen = f_main.size()
  561. buf = memoryview(bytearray(mlen + blklen))
  562. f_main.render(buf)
  563. ofs = mlen - slen
  564. # encrypt in-place
  565. if ciph is not None:
  566. ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])
  567. txt = _ssh_pem_encode(buf[:mlen])
  568. buf[ofs:mlen] = bytearray(slen)
  569. return txt
  570. _SSH_PUBLIC_KEY_TYPES = typing.Union[
  571. ec.EllipticCurvePublicKey,
  572. rsa.RSAPublicKey,
  573. dsa.DSAPublicKey,
  574. ed25519.Ed25519PublicKey,
  575. ]
  576. def load_ssh_public_key(
  577. data: bytes, backend: typing.Any = None
  578. ) -> _SSH_PUBLIC_KEY_TYPES:
  579. """Load public key from OpenSSH one-line format."""
  580. utils._check_byteslike("data", data)
  581. m = _SSH_PUBKEY_RC.match(data)
  582. if not m:
  583. raise ValueError("Invalid line format")
  584. key_type = orig_key_type = m.group(1)
  585. key_body = m.group(2)
  586. with_cert = False
  587. if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
  588. with_cert = True
  589. key_type = key_type[: -len(_CERT_SUFFIX)]
  590. kformat = _lookup_kformat(key_type)
  591. try:
  592. rest = memoryview(binascii.a2b_base64(key_body))
  593. except (TypeError, binascii.Error):
  594. raise ValueError("Invalid key format")
  595. inner_key_type, rest = _get_sshstr(rest)
  596. if inner_key_type != orig_key_type:
  597. raise ValueError("Invalid key format")
  598. if with_cert:
  599. nonce, rest = _get_sshstr(rest)
  600. public_key, rest = kformat.load_public(rest)
  601. if with_cert:
  602. serial, rest = _get_u64(rest)
  603. cctype, rest = _get_u32(rest)
  604. key_id, rest = _get_sshstr(rest)
  605. principals, rest = _get_sshstr(rest)
  606. valid_after, rest = _get_u64(rest)
  607. valid_before, rest = _get_u64(rest)
  608. crit_options, rest = _get_sshstr(rest)
  609. extensions, rest = _get_sshstr(rest)
  610. reserved, rest = _get_sshstr(rest)
  611. sig_key, rest = _get_sshstr(rest)
  612. signature, rest = _get_sshstr(rest)
  613. _check_empty(rest)
  614. return public_key
  615. def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
  616. """One-line public key format for OpenSSH"""
  617. if isinstance(public_key, ec.EllipticCurvePublicKey):
  618. key_type = _ecdsa_key_type(public_key)
  619. elif isinstance(public_key, rsa.RSAPublicKey):
  620. key_type = _SSH_RSA
  621. elif isinstance(public_key, dsa.DSAPublicKey):
  622. key_type = _SSH_DSA
  623. elif isinstance(public_key, ed25519.Ed25519PublicKey):
  624. key_type = _SSH_ED25519
  625. else:
  626. raise ValueError("Unsupported key type")
  627. kformat = _lookup_kformat(key_type)
  628. f_pub = _FragList()
  629. f_pub.put_sshstr(key_type)
  630. kformat.encode_public(public_key, f_pub)
  631. pub = binascii.b2a_base64(f_pub.tobytes()).strip()
  632. return b"".join([key_type, b" ", pub])