123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- # Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
- #
- # This file is part of paramiko.
- #
- # Paramiko is free software; you can redistribute it and/or modify it under the
- # terms of the GNU Lesser General Public License as published by the Free
- # Software Foundation; either version 2.1 of the License, or (at your option)
- # any later version.
- #
- # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
- # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- # details.
- #
- # You should have received a copy of the GNU Lesser General Public License
- # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
- # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
- """
- ECDSA keys
- """
- from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes, serialization
- from cryptography.hazmat.primitives.asymmetric import ec
- from cryptography.hazmat.primitives.asymmetric.utils import (
- decode_dss_signature,
- encode_dss_signature,
- )
- from paramiko.common import four_byte
- from paramiko.message import Message
- from paramiko.pkey import PKey
- from paramiko.ssh_exception import SSHException
- from paramiko.util import deflate_long
class _ECDSACurve(object):
    """
    Describes one supported ECDSA curve (nistp256, nistp384, ...).

    Bundles the ``cryptography`` curve class with the values derived from it:
    the key length in bits, the SSH key format identifier
    ("ecdsa-sha2-<nist name>", RFC 5656 section 6.2) and the hash class used
    for signatures, chosen from the key length per RFC 5656 section 6.2.1.
    """

    def __init__(self, curve_class, nist_name):
        self.nist_name = nist_name
        bit_length = curve_class.key_size
        self.key_length = bit_length
        # RFC 5656 6.2: the identifier is the fixed prefix plus the NIST name.
        self.key_format_identifier = "ecdsa-sha2-" + nist_name
        # RFC 5656 6.2.1: SHA-256 up to 256 bits, SHA-384 up to 384 bits,
        # SHA-512 beyond that.
        if bit_length > 384:
            self.hash_object = hashes.SHA512
        elif bit_length > 256:
            self.hash_object = hashes.SHA384
        else:
            self.hash_object = hashes.SHA256
        self.curve_class = curve_class
- class _ECDSACurveSet(object):
- """
- A collection to hold the ECDSA curves. Allows querying by oid and by key
- format identifier. The two ways in which ECDSAKey needs to be able to look
- up curves.
- """
- def __init__(self, ecdsa_curves):
- self.ecdsa_curves = ecdsa_curves
- def get_key_format_identifier_list(self):
- return [curve.key_format_identifier for curve in self.ecdsa_curves]
- def get_by_curve_class(self, curve_class):
- for curve in self.ecdsa_curves:
- if curve.curve_class == curve_class:
- return curve
- def get_by_key_format_identifier(self, key_format_identifier):
- for curve in self.ecdsa_curves:
- if curve.key_format_identifier == key_format_identifier:
- return curve
- def get_by_key_length(self, key_length):
- for curve in self.ecdsa_curves:
- if curve.key_length == key_length:
- return curve
class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.

    Only the curves registered in ``_ECDSA_CURVES`` (nistp256, nistp384 and
    nistp521) are recognised.
    """

    _ECDSA_CURVES = _ECDSACurveSet(
        [
            _ECDSACurve(ec.SECP256R1, "nistp256"),
            _ECDSACurve(ec.SECP384R1, "nistp384"),
            _ECDSACurve(ec.SECP521R1, "nistp521"),
        ]
    )

    def __init__(
        self,
        msg=None,
        data=None,
        filename=None,
        password=None,
        vals=None,
        file_obj=None,
        validate_point=True,
    ):
        """
        Create an ECDSA key from one of several sources.

        Sources are tried in this order: ``file_obj``, ``filename``, ``vals``,
        then ``msg`` (or ``data``, which is wrapped in a `.Message` when
        ``msg`` is not given).

        :param msg: `.Message` positioned at an SSH public key or certificate
            blob.
        :param data: raw bytes of such a blob.
        :param filename: path of a private key file to load.
        :param password: passphrase used when reading a private key.
        :param vals: ``(signing_key, verifying_key)`` tuple of ``cryptography``
            key objects. ``signing_key`` may be None for a public-only key;
            if ``verifying_key`` is None it is derived from ``signing_key``.
        :param file_obj: file-like object to read a private key from.
        :param validate_point: accepted for backwards compatibility; this
            implementation does not use it.

        :raises SSHException: if a public key blob names a curve other than
            the one implied by its key type, or encodes an invalid point.
        """
        self.verifying_key = None
        self.signing_key = None
        self.public_blob = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            # A private key carries its public half; fill it in so asbytes(),
            # _fields and verification work when only the signer was given.
            if self.verifying_key is None and self.signing_key is not None:
                self.verifying_key = self.signing_key.public_key()
            # Read the curve from the public key: unlike signing_key it is
            # present for public-only keys too.
            c_class = self.verifying_key.curve.__class__
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
        else:
            # Must set ecdsa_curve first; subroutines called herein may need to
            # spit out our get_name(), which relies on this.
            key_type = msg.get_text()
            # But this also means we need to hand it a real key/curve
            # identifier, so strip out any cert business. (NOTE: could push
            # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
            # feels more correct to do it here?)
            suffix = "-cert-v01@openssh.com"
            if key_type.endswith(suffix):
                key_type = key_type[: -len(suffix)]
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
                key_type
            )
            key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
            cert_types = [
                "{}-cert-v01@openssh.com".format(x) for x in key_types
            ]
            # Helper inherited from PKey; presumably checks the leading type
            # string against these lists and loads any certificate — confirm
            # in pkey.py.
            self._check_type_and_load_cert(
                msg=msg, key_type=key_types, cert_type=cert_types
            )
            curvename = msg.get_text()
            if curvename != self.ecdsa_curve.nist_name:
                raise SSHException(
                    "Can't handle curve of type {}".format(curvename)
                )
            pointinfo = msg.get_binary()
            try:
                key = ec.EllipticCurvePublicKey.from_encoded_point(
                    self.ecdsa_curve.curve_class(), pointinfo
                )
                self.verifying_key = key
            except ValueError:
                raise SSHException("Invalid public key")

    @classmethod
    def supported_key_format_identifiers(cls):
        """Return the key format identifiers of all supported curves."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    def asbytes(self):
        """
        Serialize the public key as an SSH blob.

        The blob holds the key format identifier, the NIST curve name, and the
        uncompressed point: ``four_byte`` followed by X and Y, each
        left-padded with zeros to the curve's byte length.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        key_size_bytes = (key.curve.key_size + 7) // 8

        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes

        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        # NOTE(review): this returns bytes. On Python 3 ``str(key)`` therefore
        # raises TypeError (__str__ must return str); use asbytes() instead.
        return self.asbytes()

    @property
    def _fields(self):
        # Identity tuple: key type plus the public point coordinates.
        name = self.get_name()
        numbers = self.verifying_key.public_numbers()
        return (name, numbers.x, numbers.y)

    def get_name(self):
        """Return the key format identifier, e.g. "ecdsa-sha2-nistp256"."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the curve's key length in bits."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return True if this object holds a private (signing) key."""
        return self.signing_key is not None

    def sign_ssh_data(self, data, algorithm=None):
        """
        Sign ``data`` and return the SSH signature as a `.Message`.

        The message holds the key format identifier followed by a blob of the
        two mpints r and s. ``algorithm`` is accepted for interface
        compatibility and ignored; the hash comes from the curve.
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        sig = self.signing_key.sign(data, ecdsa)
        # cryptography returns DER; SSH wants r and s as separate mpints.
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Return True if ``msg`` holds a valid signature of ``data`` by this
        key.

        Returns False when the signature's type string differs from this
        key's identifier, or when verification fails.
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        # Re-encode r and s as DER, the form cryptography's verify() expects.
        signature = encode_dss_signature(sigR, sigS)

        try:
            self.verifying_key.verify(
                signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
            )
        except InvalidSignature:
            return False
        else:
            return True

    def write_private_key_file(self, filename, password=None):
        """Write the private key to ``filename`` (TraditionalOpenSSL)."""
        self._write_private_key_file(
            filename,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    def write_private_key(self, file_obj, password=None):
        """Write the private key to ``file_obj`` (TraditionalOpenSSL)."""
        self._write_private_key(
            file_obj,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key. This factory function can be used to
        generate a new host key or authentication key.

        :param curve:
            ``cryptography`` curve instance to generate the key on; ignored
            when ``bits`` is given.
        :param progress_func: Not used for this type of key.
        :param bits:
            optional key length (256, 384 or 521) selecting the curve instead
            of ``curve``.
        :returns: A new private key (`.ECDSAKey`) object; when called on a
            subclass, an instance of that subclass.
        :raises ValueError: if ``bits`` matches no supported curve.
        """
        if bits is not None:
            curve = cls._ECDSA_CURVES.get_by_key_length(bits)
            if curve is None:
                raise ValueError("Unsupported key length: {:d}".format(bits))
            curve = curve.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return cls(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _from_private_key_file(self, filename, password):
        data = self._read_private_key_file("EC", filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        data = self._read_private_key("EC", file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        # ``data`` is a (format, payload) pair from the PKey readers above.
        pkformat, data = data
        if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
            # Payload is DER; the PEM layer is handled by the PKey readers.
            try:
                key = serialization.load_der_private_key(
                    data, password=None, backend=default_backend()
                )
            except (
                ValueError,
                AssertionError,
                TypeError,
                UnsupportedAlgorithm,
            ) as e:
                raise SSHException(str(e))
        elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
            # Payload fields: curve name, public point, private scalar.
            try:
                msg = Message(data)
                curve_name = msg.get_text()
                verkey = msg.get_binary()  # noqa: F841
                sigkey = msg.get_mpint()
                name = "ecdsa-sha2-" + curve_name
                curve = self._ECDSA_CURVES.get_by_key_format_identifier(name)
                if not curve:
                    raise SSHException("Invalid key curve identifier")
                key = ec.derive_private_key(
                    sigkey, curve.curve_class(), default_backend()
                )
            except Exception as e:
                # PKey._read_private_key_openssh() should check or return
                # keytype - parsing could fail for any reason due to wrong type
                raise SSHException(str(e))
        else:
            self._got_bad_key_format_id(pkformat)

        self.signing_key = key
        self.verifying_key = key.public_key()
        curve_class = key.curve.__class__
        self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)

    def _sigencode(self, r, s):
        # SSH ECDSA signature blob: r then s, each as an mpint.
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        # Inverse of _sigencode: read r and s back out of the blob.
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s
|