hkdf.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
  4. import typing
  5. from cryptography import utils
  6. from cryptography.exceptions import (
  7. AlreadyFinalized,
  8. InvalidKey,
  9. )
  10. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  11. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  12. class HKDF(KeyDerivationFunction):
  13. def __init__(
  14. self,
  15. algorithm: hashes.HashAlgorithm,
  16. length: int,
  17. salt: typing.Optional[bytes],
  18. info: typing.Optional[bytes],
  19. backend: typing.Any = None,
  20. ):
  21. self._algorithm = algorithm
  22. if salt is None:
  23. salt = b"\x00" * self._algorithm.digest_size
  24. else:
  25. utils._check_bytes("salt", salt)
  26. self._salt = salt
  27. self._hkdf_expand = HKDFExpand(self._algorithm, length, info)
  28. def _extract(self, key_material: bytes) -> bytes:
  29. h = hmac.HMAC(self._salt, self._algorithm)
  30. h.update(key_material)
  31. return h.finalize()
  32. def derive(self, key_material: bytes) -> bytes:
  33. utils._check_byteslike("key_material", key_material)
  34. return self._hkdf_expand.derive(self._extract(key_material))
  35. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  36. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  37. raise InvalidKey
  38. class HKDFExpand(KeyDerivationFunction):
  39. def __init__(
  40. self,
  41. algorithm: hashes.HashAlgorithm,
  42. length: int,
  43. info: typing.Optional[bytes],
  44. backend: typing.Any = None,
  45. ):
  46. self._algorithm = algorithm
  47. max_length = 255 * algorithm.digest_size
  48. if length > max_length:
  49. raise ValueError(
  50. "Cannot derive keys larger than {} octets.".format(max_length)
  51. )
  52. self._length = length
  53. if info is None:
  54. info = b""
  55. else:
  56. utils._check_bytes("info", info)
  57. self._info = info
  58. self._used = False
  59. def _expand(self, key_material: bytes) -> bytes:
  60. output = [b""]
  61. counter = 1
  62. while self._algorithm.digest_size * (len(output) - 1) < self._length:
  63. h = hmac.HMAC(key_material, self._algorithm)
  64. h.update(output[-1])
  65. h.update(self._info)
  66. h.update(bytes([counter]))
  67. output.append(h.finalize())
  68. counter += 1
  69. return b"".join(output)[: self._length]
  70. def derive(self, key_material: bytes) -> bytes:
  71. utils._check_byteslike("key_material", key_material)
  72. if self._used:
  73. raise AlreadyFinalized
  74. self._used = True
  75. return self._expand(key_material)
  76. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  77. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  78. raise InvalidKey