utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import base64
  2. import binascii
  3. import re
  4. from typing import Any, Union
  5. try:
  6. from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
  7. from cryptography.hazmat.primitives.asymmetric.utils import (
  8. decode_dss_signature,
  9. encode_dss_signature,
  10. )
  11. except ModuleNotFoundError:
  12. EllipticCurve = Any # type: ignore
  def force_bytes(value: Union[str, bytes]) -> bytes:
      """Return *value* as ``bytes``.

      ``bytes`` input is returned unchanged; ``str`` input is encoded as UTF-8.

      Raises:
          TypeError: if *value* is neither ``str`` nor ``bytes``.
      """
      if isinstance(value, bytes):
          return value
      if isinstance(value, str):
          return value.encode("utf-8")
      raise TypeError("Expected a string value")
  def base64url_decode(input: Union[str, bytes]) -> bytes:
      """Decode URL-safe base64 data whose trailing ``=`` padding may be missing.

      Args:
          input: The encoded data. ``str`` input is encoded to ASCII first.

      Returns:
          The decoded bytes.
      """
      encoded = input.encode("ascii") if isinstance(input, str) else input
      # Restore the padding needed to reach a multiple of four characters.
      # A new object is built (instead of ``+=``) so that a mutable buffer
      # passed by the caller is never modified in place.
      remainder = len(encoded) % 4
      if remainder > 0:
          encoded = encoded + b"=" * (4 - remainder)
      return base64.urlsafe_b64decode(encoded)
  def base64url_encode(input: bytes) -> bytes:
      """Encode *input* as URL-safe base64 with the ``=`` padding removed."""
      padded = base64.urlsafe_b64encode(input)
      # "=" only ever occurs as trailing padding in base64 output.
      return padded.rstrip(b"=")
  def to_base64url_uint(val: int) -> bytes:
      """Encode a non-negative integer as unpadded URL-safe base64.

      The integer is first converted with ``bytes_from_int``; zero is encoded
      as a single ``0x00`` byte rather than as an empty string.

      Raises:
          ValueError: if *val* is negative.
      """
      if val < 0:
          raise ValueError("Must be a positive integer")

      # An empty byte string (the value zero) is replaced by one zero byte.
      payload = bytes_from_int(val) or b"\x00"
      return base64url_encode(payload)
  def from_base64url_uint(val: Union[str, bytes]) -> int:
      """Decode unpadded URL-safe base64 data into an unsigned integer.

      ``str`` input is encoded to ASCII, the data is decoded with
      ``base64url_decode`` and the resulting bytes are read as a big-endian
      integer.
      """
      encoded = val.encode("ascii") if isinstance(val, str) else val
      return int.from_bytes(base64url_decode(encoded), byteorder="big")
  def number_to_bytes(num: int, num_bytes: int) -> bytes:
      """Return *num* as big-endian bytes, zero-padded to *num_bytes* bytes.

      The value is rendered as hexadecimal, left-padded with zeros to
      ``2 * num_bytes`` digits, and then converted back to binary. A value
      that needs more digits than that is not truncated.
      """
      hex_width = 2 * num_bytes
      hex_digits = "%0*x" % (hex_width, num)
      return binascii.unhexlify(hex_digits.encode("ascii"))
  def bytes_to_number(string: bytes) -> int:
      """Interpret *string* as a big-endian unsigned integer.

      The bytes are converted to hexadecimal text and parsed in base 16.
      """
      hex_digits = binascii.hexlify(string)
      return int(hex_digits, 16)
  def bytes_from_int(val: int) -> bytes:
      """Return the minimal big-endian unsigned encoding of *val*.

      Args:
          val: A non-negative integer.

      Returns:
          The shortest big-endian byte string that represents *val*; zero
          yields an empty byte string.

      Raises:
          ValueError: if *val* is negative.
      """
      # Right-shifting a negative int never reaches zero, so a shift-counting
      # loop would spin forever on negative input; reject it explicitly.
      if val < 0:
          raise ValueError("Must be a non-negative integer")

      # Round the bit length up to whole bytes.
      byte_length = (val.bit_length() + 7) // 8
      return val.to_bytes(byte_length, "big", signed=False)
  def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
      """Convert a DER-encoded signature into fixed-width ``r || s`` bytes.

      The width of each component is ``curve.key_size`` rounded up to whole
      bytes. The two integers returned by ``decode_dss_signature`` are each
      rendered with ``number_to_bytes`` and concatenated.
      """
      component_size = (curve.key_size + 7) // 8
      r, s = decode_dss_signature(der_sig)
      return b"".join(number_to_bytes(part, component_size) for part in (r, s))
  def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
      """Convert a fixed-width ``r || s`` signature into DER encoding.

      Each component is expected to occupy ``curve.key_size`` rounded up to
      whole bytes. The two halves are parsed with ``bytes_to_number`` and
      passed to ``encode_dss_signature``.

      Raises:
          ValueError: if *raw_sig* is not exactly two components long.
      """
      component_size = (curve.key_size + 7) // 8
      if len(raw_sig) != 2 * component_size:
          raise ValueError("Invalid signature")

      r_part, s_part = raw_sig[:component_size], raw_sig[component_size:]
      return encode_dss_signature(
          bytes_to_number(r_part),
          bytes_to_number(s_part),
      )
  # Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
  # Labels accepted between the BEGIN/END markers.
  _PEMS = {
      b"CERTIFICATE",
      b"TRUSTED CERTIFICATE",
      b"PRIVATE KEY",
      b"PUBLIC KEY",
      b"ENCRYPTED PRIVATE KEY",
      b"OPENSSH PRIVATE KEY",
      b"DSA PRIVATE KEY",
      b"RSA PRIVATE KEY",
      b"RSA PUBLIC KEY",
      b"EC PRIVATE KEY",
      b"DH PARAMETERS",
      b"NEW CERTIFICATE REQUEST",
      b"CERTIFICATE REQUEST",
      b"SSH2 PUBLIC KEY",
      b"SSH2 ENCRYPTED PRIVATE KEY",
      b"X509 CRL",
  }

  # A BEGIN line with one of the labels above, a non-empty body, and an END
  # line repeating the same label (back-reference \1). Line breaks may be LF
  # or CRLF. The pattern is written with explicit escapes so that it does not
  # depend on how this source file is laid out.
  _PEM_RE = re.compile(
      b"----[- ]BEGIN ("
      + b"|".join(_PEMS)
      + b")[- ]----\r?\n"
      + b".+?\r?\n"
      + b"----[- ]END \\1[- ]----\r?\n?",
      re.DOTALL,
  )


  def is_pem_format(key: bytes) -> bool:
      """Return ``True`` if *key* contains a PEM block with a known label."""
      return _PEM_RE.search(key) is not None
  # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
  _CERT_SUFFIX = b"-cert-v01@openssh.com"
  # Two whitespace-separated tokens at the very start of the data.
  _SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
  _SSH_KEY_FORMATS = [
      b"ssh-ed25519",
      b"ssh-rsa",
      b"ssh-dss",
      b"ecdsa-sha2-nistp256",
      b"ecdsa-sha2-nistp384",
      b"ecdsa-sha2-nistp521",
  ]


  def is_ssh_key(key: bytes) -> bool:
      """Return ``True`` if *key* looks like an SSH key or certificate.

      The data qualifies when it contains any of the known key-format names
      anywhere, or when its first token ends with the OpenSSH certificate
      suffix and is followed by a second token.
      """
      for key_format in _SSH_KEY_FORMATS:
          if key_format in key:
              return True

      header = _SSH_PUBKEY_RC.match(key)
      return header is not None and header.group(1).endswith(_CERT_SUFFIX)