12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757 |
- # This file is dual licensed under the terms of the Apache License, Version
- # 2.0, and the BSD License. See the LICENSE file in the root of this repository
- # for complete details.
- import binascii
- import os
- import re
- import typing
- from base64 import encodebytes as _base64_encode
- from cryptography import utils
- from cryptography.exceptions import UnsupportedAlgorithm
- from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives.serialization import (
- Encoding,
- NoEncryption,
- PrivateFormat,
- PublicFormat,
- )
try:
    from bcrypt import kdf as _bcrypt_kdf
except ImportError:
    # bcrypt is optional: without it, only unencrypted keys can be handled.
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Stand-in for ``bcrypt.kdf`` that always reports it is missing."""
        raise UnsupportedAlgorithm("Need bcrypt module")

else:
    _bcrypt_supported = True
# SSH key-type identifiers, as stored in key blobs and one-line public keys.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Certificate key types are a plain key type followed by this suffix.
_CERT_SUFFIX = b"-cert-v01@openssh.com"
# Captures the key type and the base64 body at the start of a one-line
# public key; anything after the body (e.g. a comment) is not examined.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded private-key blob, and its PEM armor lines.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names stored in the private-key header ("none" is used for
# both fields of an unencrypted key).
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
# Longest password, in bytes, accepted when serializing a private key.
_MAX_PASSWORD = 72
# Locates the base64 body between the PEM armor lines; DOTALL lets ``.``
# span the embedded line breaks. ``re`` can search bytes-like objects
# directly, which is why a regex is used instead of bytes methods.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
# Padding bytes 1, 2, ..., 16 appended to the secret section; 16 bytes
# covers the largest block length listed in _SSH_CIPHERS.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
# Ciphers supported for key wrapping. Each value is
# (algorithm class, key length in bytes, mode class, IV length in bytes);
# the last entry is also used as the block length for padding checks.
_SSH_CIPHERS: typing.Dict[
    bytes,
    typing.Tuple[
        typing.Type[algorithms.AES],
        int,
        typing.Union[typing.Type[modes.CTR], typing.Type[modes.CBC]],
        int,
    ],
] = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}
# Map the cryptography curve name to the SSH ECDSA key type.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH ECDSA key-type identifier for ``public_key``'s curve.

    :raises ValueError: the curve has no SSH ECDSA key type in
        ``_ECDSA_KEY_TYPE``.
    """
    curve_name = public_key.curve.name
    ssh_key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if ssh_key_type is None:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return ssh_key_type
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Armor ``data`` as base64 between ``prefix`` and ``suffix``.

    The base64 body comes from ``base64.encodebytes``, which emits
    newline-terminated lines, so the suffix starts on its own line.
    """
    encoded_body = _base64_encode(data)
    return prefix + encoded_body + suffix
def _check_block_size(data: bytes, block_len: int) -> None:
    """Raise ValueError unless ``data`` is non-empty whole blocks.

    ``data`` must be non-empty and its length a multiple of ``block_len``.
    """
    aligned = bool(data) and len(data) % block_len == 0
    if not aligned:
        raise ValueError("Corrupt data: missing padding")
def _check_empty(data: bytes) -> None:
    """Raise ValueError if any input is left over after parsing."""
    if not data:
        return
    raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR]]:
    """Derive key and IV with bcrypt-pbkdf and build the wrapping cipher.

    :param ciphername: key into ``_SSH_CIPHERS``.
    :param password: passphrase; must be non-empty.
    :param salt: KDF salt.
    :param rounds: bcrypt round count.
    :raises ValueError: no (or empty) password was supplied.
    """
    if not password:
        raise ValueError("Key is password-protected.")
    cipher_cls, key_size, mode_cls, iv_size = _SSH_CIPHERS[ciphername]
    # One KDF call yields key material followed directly by the IV.
    material = _bcrypt_kdf(password, salt, key_size + iv_size, rounds, True)
    key, iv = material[:key_size], material[key_size:]
    return Cipher(cipher_cls(key), mode_cls(iv))
def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Split a big-endian uint32 off the front of ``data``.

    Returns the value and the remaining bytes.
    """
    if len(data) >= 4:
        return int.from_bytes(data[:4], byteorder="big"), data[4:]
    raise ValueError("Invalid data")
def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Split a big-endian uint64 off the front of ``data``.

    Returns the value and the remaining bytes.
    """
    if len(data) >= 8:
        return int.from_bytes(data[:8], byteorder="big"), data[8:]
    raise ValueError("Invalid data")
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Split one SSH ``string`` (uint32 length, then payload) off ``data``.

    Returns the payload and the remaining bytes.

    :raises ValueError: the length prefix is truncated or exceeds the
        available data.
    """
    length, rest = _get_u32(data)
    if length <= len(rest):
        return rest[:length], rest[length:]
    raise ValueError("Invalid data")
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Split one SSH ``mpint`` off ``data`` and decode it as an int.

    A set high bit in the first payload byte would mean a negative value,
    which is rejected with ValueError.
    """
    raw, rest = _get_sshstr(data)
    if raw and raw[0] >= 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest
def _to_mpint(val: int) -> bytes:
    """Encode a non-negative int as an SSH mpint payload (no length prefix).

    Zero encodes as empty bytes. A leading 0x00 byte is added whenever the
    top bit of the most significant byte would otherwise be set, so the
    value can never be read back as negative.

    :raises ValueError: ``val`` is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # bit_length() // 8 + 1 bytes always leaves room for a zero sign bit.
    return val.to_bytes(val.bit_length() // 8 + 1, "big")
class _FragList:
    """Build recursive structure without data copy.

    Fragments (bytes-like objects) are collected in ``flist`` and are only
    concatenated when ``render`` or ``tobytes`` is called. Nesting one
    ``_FragList`` inside another via ``put_sshstr`` splices the inner
    fragments in rather than copying their contents.
    """

    flist: typing.List[bytes]

    def __init__(self, init: typing.Optional[typing.List[bytes]] = None) -> None:
        # Always start from a fresh list so a caller-supplied list is never
        # aliased (and later mutated) by this object.
        self.flist = []
        if init:
            self.flist.extend(init)

    def put_raw(self, val: bytes) -> None:
        """Add plain bytes"""
        self.flist.append(val)

    def put_u32(self, val: int) -> None:
        """Big-endian uint32"""
        self.flist.append(val.to_bytes(length=4, byteorder="big"))

    def put_sshstr(self, val: typing.Union[bytes, "_FragList"]) -> None:
        """Bytes (or a nested _FragList) prefixed with u32 length"""
        if isinstance(val, (bytes, memoryview, bytearray)):
            self.put_u32(len(val))
            self.flist.append(val)
        else:
            self.put_u32(val.size())
            self.flist.extend(val.flist)

    def put_mpint(self, val: int) -> None:
        """Big-endian bigint prefixed with u32 length"""
        self.put_sshstr(_to_mpint(val))

    def size(self) -> int:
        """Current number of bytes"""
        return sum(map(len, self.flist))

    def render(self, dstbuf: memoryview, pos: int = 0) -> int:
        """Write all fragments into ``dstbuf`` starting at ``pos``.

        Returns the position just past the last byte written.
        """
        for frag in self.flist:
            flen = len(frag)
            start, pos = pos, pos + flen
            dstbuf[start:pos] = frag
        return pos

    def tobytes(self) -> bytes:
        """Return all fragments concatenated as a single bytes object."""
        # bytes.join accepts any bytes-like fragment (bytes, bytearray,
        # memoryview) and copies once, with no intermediate buffer.
        return b"".join(self.flist)
class _SSHFormatRSA:
    """SSH wire layout for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Read the public ``(e, n)`` pair; return it and the unread rest."""
        exponent, data = _get_mpint(data)
        modulus, data = _get_mpint(data)
        return (exponent, modulus), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from the leading ``(e, n)`` fields."""
        (exponent, modulus), remainder = self.get_public(data)
        numbers = rsa.RSAPublicNumbers(exponent, modulus)
        return numbers.public_key(), remainder

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key; ``(e, n)`` must equal ``pubfields``.

        The CRT exponents dmp1/dmq1 are not stored in the SSH format, so
        they are recomputed from d, p and q.
        """
        values = []
        for _ in range(6):
            value, data = _get_mpint(data)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")

        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        public_numbers = rsa.RSAPublicNumbers(e, n)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, public_numbers
        )
        return private_numbers.private_key(), data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append ``e`` then ``n`` to ``f_pub``."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append n, e, d, iqmp, p, q to ``f_priv``."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """SSH wire layout for DSA ("ssh-dss") keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x

    Keys whose prime ``p`` is not exactly 1024 bits are rejected, both on
    load and on encode.
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the four public mpints ``(p, q, g, y)`` and the unread rest."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from the leading public fields."""
        (p, q, g, y), remainder = self.get_public(data)
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return numbers.public_key(), remainder

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key; its public part must equal ``pubfields``."""
        public_values, data = self.get_public(data)
        x, data = _get_mpint(data)

        if public_values != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")

        p, q, g, y = public_values
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return dsa.DSAPrivateNumbers(x, numbers).private_key(), data

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append p, q, g, y to ``f_pub`` after checking the key size."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)
        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private value ``x``."""
        self.encode_public(private_key.public_key(), f_priv)
        secret_x = private_key.private_numbers().x
        f_priv.put_mpint(secret_x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject keys whose prime ``p`` is not exactly 1024 bits long."""
        p_bits = public_numbers.parameter_numbers.p.bit_length()
        if p_bits != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # ssh_curve_name is the curve identifier stored inside the key blob
        # (e.g. b"nistp256"); curve is the matching cryptography curve.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """ECDSA public fields.

        Reads the curve name and the encoded point, and returns
        ``(curve, point)`` together with the unread remainder.

        :raises ValueError: curve name differs from this format's curve,
            or the point string is empty.
        :raises NotImplementedError: the point is not in uncompressed form
            (its first byte is not 0x04).
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # An empty point string would make point[0] raise IndexError; report
        # it as malformed input like every other parse failure here.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        The curve/point copy stored with the secret must equal
        ``pubfields`` from the public section.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)
        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Write the curve name and the uncompressed X9.62 point."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Write the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()
        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """SSH wire layout for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point  (32-byte secret followed by the point)
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the public point; return ``(point,)`` and the unread rest."""
        point, remainder = _get_sshstr(data)
        return (point,), remainder

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from the leading point field."""
        (point,), remainder = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, remainder

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key.

        Both the standalone point and the copy embedded after the secret
        must agree with each other and with ``pubfields``.
        """
        (point,), data = self.get_public(data)
        keypair, data = _get_sshstr(data)
        secret, embedded_point = keypair[:32], keypair[32:]
        if point != embedded_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), data

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw 32-byte public point as an SSH string."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the point, then the ``secret || point`` keypair string."""
        public_key = private_key.public_key()
        secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([secret, point]))
# Dispatch table: SSH key type -> wire-format handler used by the load and
# serialize functions (looked up through _lookup_kformat). Each ECDSA
# handler pairs the in-blob curve identifier with its cryptography curve.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
def _lookup_kformat(key_type: bytes):
    """Return the format handler for ``key_type``.

    Non-``bytes`` inputs (e.g. memoryview slices from the parsers) are
    copied to ``bytes`` first, so they can be used as a dict key and appear
    readably in the error message.

    :raises UnsupportedAlgorithm: no handler exists for ``key_type``.
    """
    if not isinstance(key_type, bytes):
        key_type = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(key_type)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")
    return kformat
# Private key classes accepted by serialize_ssh_private_key and returned by
# load_ssh_private_key.
_SSH_PRIVATE_KEY_TYPES = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> _SSH_PRIVATE_KEY_TYPES:
    """Load private key from OpenSSH custom encoding.

    Parses the first ``-----BEGIN/END OPENSSH PRIVATE KEY-----`` block in
    ``data``. Only files holding exactly one key are supported. Encrypted
    keys must use the bcrypt KDF and a cipher listed in ``_SSH_CIPHERS``.

    :param data: bytes-like object containing the armored key.
    :param password: passphrase for an encrypted key; may be None when the
        key is not encrypted.
    :param backend: unused.
    :raises ValueError: the input is malformed or corrupt, or an encrypted
        key was given without a password.
    :raises UnsupportedAlgorithm: unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        # Convert before formatting: the repr of a memoryview shows only
        # its address, not the KDF name.
        kdfname_bytes = kdfname.tobytes()
        if kdfname_bytes != _BCRYPT:
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname_bytes!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes][3]
        _check_block_size(edata, blklen)
        # bcrypt options: string salt, uint32 rounds, and nothing else.
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)

    # The two check values must match; a mismatch after decryption
    # indicates a wrong password or corrupted data.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    return private_key
def serialize_ssh_private_key(
    private_key: _SSH_PRIVATE_KEY_TYPES,
    password: typing.Optional[bytes] = None,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Produces an "openssh-key-v1" blob wrapped in
    ``-----BEGIN/END OPENSSH PRIVATE KEY-----`` armor. With a non-empty
    ``password``, the secret section is encrypted with ``_DEFAULT_CIPHER``
    using a bcrypt-derived key (``_DEFAULT_ROUNDS`` rounds, random 16-byte
    salt). Otherwise the cipher and KDF names are both "none".

    :param private_key: RSA, DSA, ECDSA or Ed25519 private key.
    :param password: optional passphrase of at most ``_MAX_PASSWORD`` bytes.
    :raises ValueError: the password is too long, or the key type (or EC
        curve) is not supported.
    :returns: the armored private key as bytes.
    """
    if password is not None:
        utils._check_bytes("password", password)
    if password and len(password) > _MAX_PASSWORD:
        raise ValueError(
            "Passwords longer than 72 bytes are not supported by "
            "OpenSSH private key format"
        )

    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters: cipher, KDF and KDF options (salt + rounds) for an
    # encrypted key, or "none"/"none" with empty options otherwise
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice; the loader compares the two copies
    # to detect a wrong password or corruption.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with bytes 1, 2, ... up to a multiple of blklen; an already
    # aligned section still receives a full block of padding.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray; the extra blklen bytes give update_into an
    # output buffer longer than its input (NOTE: cryptography documents a
    # minimum of len(data) + block_size - 1 bytes for update_into)
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # The secret section is the last field, so it occupies buf[ofs:mlen].
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    txt = _ssh_pem_encode(buf[:mlen])
    # Overwrite the secret section held in buf now that the output exists.
    buf[ofs:mlen] = bytearray(slen)
    return txt
# Public key classes accepted by serialize_ssh_public_key and returned by
# load_ssh_public_key.
_SSH_PUBLIC_KEY_TYPES = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> _SSH_PUBLIC_KEY_TYPES:
    """Parse a one-line OpenSSH public key (``<type> <base64> [comment]``).

    Plain keys and ``*-cert-v01@openssh.com`` certificates are accepted.
    For a certificate, the trailing certificate fields are walked to check
    the framing but are otherwise discarded; only the embedded public key
    is returned. ``backend`` is unused.

    :raises ValueError: malformed line, invalid base64, mismatched inner
        key type, truncated fields, or trailing data.
    :raises UnsupportedAlgorithm: unknown key type.
    """
    utils._check_byteslike("data", data)
    match = _SSH_PUBKEY_RC.match(data)
    if not match:
        raise ValueError("Invalid line format")
    orig_key_type = match.group(1)
    key_body = match.group(2)

    suffix_len = len(_CERT_SUFFIX)
    with_cert = orig_key_type[-suffix_len:] == _CERT_SUFFIX
    key_type = orig_key_type[:-suffix_len] if with_cert else orig_key_type
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid key format")

    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")

    if with_cert:
        _nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)

    if with_cert:
        # serial, type, key id, principals, valid after, valid before,
        # critical options, extensions, reserved, signature key, signature
        cert_field_readers = (
            _get_u64,
            _get_u32,
            _get_sshstr,
            _get_sshstr,
            _get_u64,
            _get_u64,
            _get_sshstr,
            _get_sshstr,
            _get_sshstr,
            _get_sshstr,
            _get_sshstr,
        )
        for read_field in cert_field_readers:
            _, rest = read_field(rest)

    _check_empty(rest)
    return public_key
def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
    """Render ``public_key`` as a one-line OpenSSH ``<type> <base64>`` key.

    No comment field is appended.

    :raises ValueError: unsupported key type or EC curve.
    """
    if isinstance(public_key, ec.EllipticCurvePublicKey):
        key_type = _ecdsa_key_type(public_key)
    else:
        for key_class, type_name in (
            (rsa.RSAPublicKey, _SSH_RSA),
            (dsa.DSAPublicKey, _SSH_DSA),
            (ed25519.Ed25519PublicKey, _SSH_ED25519),
        ):
            if isinstance(public_key, key_class):
                key_type = type_name
                break
        else:
            raise ValueError("Unsupported key type")

    kformat = _lookup_kformat(key_type)
    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)
    # b2a_base64 appends a newline; strip it for the single-line format.
    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return key_type + b" " + encoded
|