123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- # Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
- #
- # This file is part of paramiko.
- #
- # Paramiko is free software; you can redistribute it and/or modify it under the
- # terms of the GNU Lesser General Public License as published by the Free
- # Software Foundation; either version 2.1 of the License, or (at your option)
- # any later version.
- #
- # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
- # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- # details.
- #
- # You should have received a copy of the GNU Lesser General Public License
- # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
- # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
- """
- Useful functions used by the rest of paramiko.
- """
- from __future__ import generators
- import errno
- import sys
- import struct
- import traceback
- import threading
- import logging
- from paramiko.common import DEBUG, zero_byte, xffffffff, max_byte
- from paramiko.py3compat import PY2, long, byte_chr, byte_ord, b
- from paramiko.config import SSHConfig
def inflate_long(s, always_positive=False):
    """
    Turn a normalized big-endian byte string into a long integer
    (adapted from Crypto.Util.number).

    :param s: byte string to convert.
    :param bool always_positive:
        when true, the value is never treated as negative, even if the high
        bit of the first byte is set.
    :return:
        the decoded integer; when ``always_positive`` is false and the first
        byte is ``>= 0x80`` the result is negative.
    """
    is_negative = (
        not always_positive and len(s) > 0 and byte_ord(s[0]) >= 0x80
    )
    remainder = len(s) % 4
    if remainder:
        # Left-pad to a multiple of four bytes so the data can be consumed
        # one 32-bit word at a time; negative values are padded with
        # ``max_byte`` so the sign is preserved.
        pad = max_byte if is_negative else zero_byte
        # never convert this to ``s +=`` because this is a string, not a number
        # noinspection PyAugmentAssignment
        s = pad * (4 - remainder) + s
    value = long(0)
    for offset in range(0, len(s), 4):
        (word,) = struct.unpack(">I", s[offset : offset + 4])
        value = (value << 32) + word
    if is_negative:
        # The words were accumulated as unsigned; subtract 2**(bits) to get
        # the negative value back.
        value -= long(1) << (8 * len(s))
    return value
# Iterating a byte string yields one-character strings on Python 2 and plain
# integers on Python 3, so the "all zero" / "all ones" markers that
# ``deflate_long`` compares against differ between the two.
deflate_zero = zero_byte if PY2 else 0
deflate_ff = max_byte if PY2 else 0xff


def deflate_long(n, add_sign_padding=True):
    """
    Turn a long integer into a normalized big-endian byte string
    (adapted from Crypto.Util.number).

    :param n: the integer to convert (passed through ``long``).
    :param bool add_sign_padding:
        when true, prepend one extra byte whenever the leading byte's high
        bit would otherwise disagree with the sign of ``n``.
    :return: the byte string, with redundant leading bytes stripped.
    """
    # after much testing, this algorithm was deemed to be the fastest
    n = long(n)
    packed = bytes()
    while n != 0 and n != -1:
        packed = struct.pack(">I", n & xffffffff) + packed
        n >>= 32
    # ``n`` is now 0 (non-negative input) or -1 (negative input); that decides
    # which leading filler byte is redundant and may be stripped.
    filler = deflate_zero if n == 0 else deflate_ff
    first_significant = None
    for position, item in enumerate(packed):
        if item != filler:
            first_significant = position
            break
    if first_significant is None:
        # degenerate case, n was either 0 or -1
        packed = zero_byte if n == 0 else max_byte
    else:
        packed = packed[first_significant:]
    if add_sign_padding:
        lead = byte_ord(packed[0])
        if n == 0 and lead >= 0x80:
            packed = zero_byte + packed
        if n == -1 and lead < 0x80:
            packed = max_byte + packed
    return packed
def format_binary(data, prefix=""):
    """
    Format ``data`` as a hex dump, 16 bytes per line.

    :param data: byte string to dump.
    :param str prefix: text prepended to every output line.
    :return:
        a list of strings, one per 16-byte chunk, each produced by
        `format_binary_line`; empty when ``data`` is empty.
    """
    total = len(data)
    return [
        prefix + format_binary_line(data[start : start + 16])
        for start in range(0, total, 16)
    ]
def format_binary_line(data):
    """
    Format one chunk of bytes as a single hex-dump line.

    :param data: byte string holding the bytes for this line.
    :return:
        the bytes as space-separated two-digit uppercase hex, left-justified
        in a 50-character column, followed by a space and a text rendering
        of the same bytes.
    """
    codes = [byte_ord(c) for c in data]
    hex_column = " ".join("{:02X}".format(code) for code in codes)
    # ``(code + 63) // 95`` is 1 only for codes 32..126, which selects the
    # character itself from the 4-character template; codes 0..31 and
    # 127..255 select one of the "." placeholders instead.
    text_column = "".join(
        ".{:c}..".format(code)[(code + 63) // 95] for code in codes
    )
    return "{:50s} {}".format(hex_column, text_column)
def safe_string(s):
    """
    Return a copy of ``s`` with bytes outside the range 32..127 escaped.

    :param s: byte string to sanitize.
    :return:
        a byte string in which every byte whose value is in 32..127
        (inclusive) is kept as-is and every other byte is replaced by
        ``%XX``, where ``XX`` is its two-digit uppercase hex value.
    """
    pieces = []
    for c in s:
        code = byte_ord(c)
        if code < 32 or code > 127:
            pieces.append(b("%{:02X}".format(code)))
        else:
            pieces.append(byte_chr(code))
    return b"".join(pieces)
def bit_length(n):
    """
    Return the number of bits needed to represent ``n``.

    Uses ``n.bit_length()`` when the object provides it. Otherwise the count
    is derived from the byte string produced by ``deflate_long(n, False)``.

    :param n: the integer to measure.
    :return: the bit length as an `int`.
    """
    try:
        return n.bit_length()
    except AttributeError:
        normalized = deflate_long(n, False)
        top = byte_ord(normalized[0])
        if top == 0:
            return 1
        total_bits = 8 * len(normalized)
        # Discount the leading zero bits of the most significant byte.
        mask = 0x80
        while not (top & mask):
            mask >>= 1
            total_bits -= 1
        return total_bits
def tb_strings():
    """
    Format the exception reported by ``sys.exc_info()`` as a list of lines.

    :return:
        the text from ``traceback.format_exception``, joined and then split
        on ``"\\n"``.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    formatted = traceback.format_exception(exc_type, exc_value, exc_tb)
    return "".join(formatted).split("\n")
def generate_key_bytes(hash_alg, salt, key, nbytes):
    """
    Given a password, passphrase, or other human-source key, scramble it
    through a secure hash into some keyworthy bytes. This specific algorithm
    is used for encrypting/decrypting private key files.

    Each round hashes the previous round's digest (if any), then the key,
    then the salt; digests are concatenated until ``nbytes`` bytes have been
    produced.

    :param function hash_alg: A function which creates a new hash object, such
        as ``hashlib.sha256``.
    :param salt: data to salt the hash with; only the first 8 bytes are used.
    :type salt: byte string
    :param str key: human-entered password or passphrase.
    :param int nbytes: number of bytes to generate.
    :return: Key data as a byte string (empty when ``nbytes <= 0``).
    """
    if len(salt) > 8:
        salt = salt[:8]
    pieces = []
    previous = bytes()
    remaining = nbytes
    while remaining > 0:
        hasher = hash_alg()
        if previous:
            # Chain the rounds: feed the prior digest in first.
            hasher.update(previous)
        hasher.update(b(key))
        hasher.update(salt)
        previous = hasher.digest()
        take = min(remaining, len(previous))
        pieces.append(previous[:take])
        remaining -= take
    return bytes().join(pieces)
def load_host_keys(filename):
    """
    Read a file of known SSH host keys, in the format used by openssh, and
    return a compound dict of ``hostname -> keytype ->`` `PKey
    <paramiko.pkey.PKey>`. The hostname may be an IP address or DNS name. The
    keytype will be either ``"ssh-rsa"`` or ``"ssh-dss"``.

    This type of file unfortunately doesn't exist on Windows, but on posix,
    it will usually be stored in ``os.path.expanduser("~/.ssh/known_hosts")``.

    Since 1.5.3, this is just a wrapper around `.HostKeys`.

    :param str filename: name of the file to read host keys from
    :return:
        nested dict of `.PKey` objects, indexed by hostname and then keytype
    """
    # Imported at call time rather than at the top of the module.
    from paramiko.hostkeys import HostKeys

    host_keys = HostKeys(filename)
    return host_keys
def parse_ssh_config(file_obj):
    """
    Provided only as a backward-compatible wrapper around `.SSHConfig`.

    Creates a fresh `.SSHConfig`, feeds ``file_obj`` to its ``parse`` method
    and returns it.

    .. deprecated:: 2.7
        Use `SSHConfig.from_file` instead.

    :param file_obj: object handed unchanged to ``SSHConfig.parse``.
    :return: the populated `.SSHConfig` instance.
    """
    parsed = SSHConfig()
    parsed.parse(file_obj)
    return parsed
def lookup_ssh_host_config(hostname, config):
    """
    Provided only as a backward-compatible wrapper around `.SSHConfig`.

    :param hostname: value passed to ``config.lookup``.
    :param config: object providing a ``lookup(hostname)`` method.
    :return: whatever ``config.lookup(hostname)`` returns.
    """
    host_settings = config.lookup(hostname)
    return host_settings
def mod_inverse(x, m):
    """
    Compute the inverse of ``x`` modulo ``m`` with the extended Euclidean
    algorithm.

    No check is made that ``x`` and ``m`` are coprime; if they are not, the
    returned value is not an inverse.

    :param int x: the value to invert.
    :param int m: the modulus.
    :return: the coefficient of ``x``, shifted into ``[0, m)`` when negative.
    """
    # Only the coefficient of ``x`` is needed, so the coefficient of ``m``
    # is not tracked.
    coeff, next_coeff = 0, 1
    rem, next_rem = m, x
    while next_rem > 0:
        quotient = rem // next_rem
        coeff, next_coeff = next_coeff, coeff - next_coeff * quotient
        rem, next_rem = next_rem, rem - next_rem * quotient
    return coeff + m if coeff < 0 else coeff
# Per-thread storage for the assigned id, the last id handed out, and the
# lock guarding that counter.
_g_thread_data = threading.local()
_g_thread_counter = 0
_g_thread_lock = threading.Lock()


def get_thread_id():
    """
    Return a small integer identifying the calling thread.

    The first call made from a given thread takes the next value of a
    module-wide counter (incremented under a lock) and stores it in
    thread-local data; later calls from that thread return the stored value.
    """
    global _g_thread_counter
    if not hasattr(_g_thread_data, "id"):
        with _g_thread_lock:
            _g_thread_counter += 1
            _g_thread_data.id = _g_thread_counter
    return _g_thread_data.id
def log_to_file(filename, level=DEBUG):
    """
    Send paramiko logs to a logfile, if they're not already going somewhere.

    Does nothing when the ``"paramiko"`` logger already has at least one
    handler attached.

    :param str filename: path of the log file; it is opened in append mode.
    :param level: log level to set on the ``"paramiko"`` logger.
    """
    logger = logging.getLogger("paramiko")
    if logger.handlers:
        return
    logger.setLevel(level)
    # Use a FileHandler rather than wrapping a manually opened file in a
    # StreamHandler: the handler then owns the file and closes it when the
    # handler itself is closed (e.g. by ``logging.shutdown()``), instead of
    # leaving an open file object that nothing ever closes.
    handler = logging.FileHandler(filename, "a")
    # ``_threadid`` is not a standard record attribute; it is set on each
    # record by ``PFilter.filter``.
    frm = "%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d"
    frm += " %(name)s: %(message)s"
    handler.setFormatter(logging.Formatter(frm, "%Y%m%d-%H:%M:%S"))
    logger.addHandler(handler)
# make only one filter object, so it doesn't get applied more than once
class PFilter(object):
    """Logging filter that tags each record with the calling thread's id."""

    def filter(self, record):
        """
        Store ``get_thread_id()`` on ``record`` as ``_threadid``.

        :return: always ``True``, so no record is ever rejected.
        """
        thread_id = get_thread_id()
        record._threadid = thread_id
        return True


_pfilter = PFilter()


def get_logger(name):
    """
    Return the logger called ``name`` with the shared `PFilter` attached.

    :param str name: logger name passed to ``logging.getLogger``.
    :return: the `logging.Logger` instance.
    """
    named_logger = logging.getLogger(name)
    named_logger.addFilter(_pfilter)
    return named_logger
def retry_on_signal(function):
    """
    Call ``function`` repeatedly until it stops raising an EINTR error.

    :param function: zero-argument callable to invoke.
    :return: the value returned by the first call that does not raise EINTR.
    :raises EnvironmentError:
        re-raised unchanged when its ``errno`` is anything but ``EINTR``.
    """
    while True:
        try:
            result = function()
        except EnvironmentError as err:
            if err.errno == errno.EINTR:
                # Interrupted by a signal: simply try again.
                continue
            raise
        else:
            return result
def constant_time_bytes_eq(a, b):
    """
    Compare two byte strings without stopping at the first difference.

    After the length check, every pair of bytes is XORed and the results are
    OR-ed together; there is no early exit once a mismatch has been seen.

    :param a: first byte string.
    :param b: second byte string.
    :return: ``True`` if both have equal length and contents, else ``False``.
    """
    if len(a) != len(b):
        return False
    # noinspection PyUnresolvedReferences
    index_range = xrange if PY2 else range  # noqa: F821
    difference = 0
    for position in index_range(len(a)):
        difference |= byte_ord(a[position]) ^ byte_ord(b[position])
    return difference == 0
class ClosingContextManager(object):
    """
    Mixin that makes an object usable in a ``with`` statement.

    Entering the block yields the object itself; leaving it calls
    ``self.close()``, which this class does not define itself.
    """

    def __enter__(self):
        """Return ``self`` as the value bound by ``with ... as``."""
        return self

    def __exit__(self, type, value, traceback):
        """
        Call ``self.close()`` on exit.

        Nothing is returned, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        self.close()
def clamp_value(minimum, val, maximum):
    """
    Limit ``val`` to the range ``minimum``..``maximum``.

    The upper bound is applied first and the lower bound second, so
    ``minimum`` is returned if the two bounds are inverted.

    :return: ``max(minimum, min(val, maximum))``.
    """
    capped = min(val, maximum)
    return max(minimum, capped)
|