concatkdf.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import typing
  5. from cryptography import utils
  6. from cryptography.exceptions import (
  7. AlreadyFinalized,
  8. InvalidKey,
  9. )
  10. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  11. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  12. def _int_to_u32be(n: int) -> bytes:
  13. return n.to_bytes(length=4, byteorder="big")
  14. def _common_args_checks(
  15. algorithm: hashes.HashAlgorithm,
  16. length: int,
  17. otherinfo: typing.Optional[bytes],
  18. ) -> None:
  19. max_length = algorithm.digest_size * (2**32 - 1)
  20. if length > max_length:
  21. raise ValueError(
  22. "Cannot derive keys larger than {} bits.".format(max_length)
  23. )
  24. if otherinfo is not None:
  25. utils._check_bytes("otherinfo", otherinfo)
  26. def _concatkdf_derive(
  27. key_material: bytes,
  28. length: int,
  29. auxfn: typing.Callable[[], hashes.HashContext],
  30. otherinfo: bytes,
  31. ) -> bytes:
  32. utils._check_byteslike("key_material", key_material)
  33. output = [b""]
  34. outlen = 0
  35. counter = 1
  36. while length > outlen:
  37. h = auxfn()
  38. h.update(_int_to_u32be(counter))
  39. h.update(key_material)
  40. h.update(otherinfo)
  41. output.append(h.finalize())
  42. outlen += len(output[-1])
  43. counter += 1
  44. return b"".join(output)[:length]
  45. class ConcatKDFHash(KeyDerivationFunction):
  46. def __init__(
  47. self,
  48. algorithm: hashes.HashAlgorithm,
  49. length: int,
  50. otherinfo: typing.Optional[bytes],
  51. backend: typing.Any = None,
  52. ):
  53. _common_args_checks(algorithm, length, otherinfo)
  54. self._algorithm = algorithm
  55. self._length = length
  56. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  57. self._used = False
  58. def _hash(self) -> hashes.Hash:
  59. return hashes.Hash(self._algorithm)
  60. def derive(self, key_material: bytes) -> bytes:
  61. if self._used:
  62. raise AlreadyFinalized
  63. self._used = True
  64. return _concatkdf_derive(
  65. key_material, self._length, self._hash, self._otherinfo
  66. )
  67. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  68. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  69. raise InvalidKey
  70. class ConcatKDFHMAC(KeyDerivationFunction):
  71. def __init__(
  72. self,
  73. algorithm: hashes.HashAlgorithm,
  74. length: int,
  75. salt: typing.Optional[bytes],
  76. otherinfo: typing.Optional[bytes],
  77. backend: typing.Any = None,
  78. ):
  79. _common_args_checks(algorithm, length, otherinfo)
  80. self._algorithm = algorithm
  81. self._length = length
  82. self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
  83. if algorithm.block_size is None:
  84. raise TypeError(
  85. "{} is unsupported for ConcatKDF".format(algorithm.name)
  86. )
  87. if salt is None:
  88. salt = b"\x00" * algorithm.block_size
  89. else:
  90. utils._check_bytes("salt", salt)
  91. self._salt = salt
  92. self._used = False
  93. def _hmac(self) -> hmac.HMAC:
  94. return hmac.HMAC(self._salt, self._algorithm)
  95. def derive(self, key_material: bytes) -> bytes:
  96. if self._used:
  97. raise AlreadyFinalized
  98. self._used = True
  99. return _concatkdf_derive(
  100. key_material, self._length, self._hmac, self._otherinfo
  101. )
  102. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  103. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  104. raise InvalidKey