scrypt.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import sys
  5. import typing
  6. from cryptography import utils
  7. from cryptography.exceptions import (
  8. AlreadyFinalized,
  9. InvalidKey,
  10. UnsupportedAlgorithm,
  11. )
  12. from cryptography.hazmat.primitives import constant_time
  13. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  14. # This is used by the scrypt tests to skip tests that require more memory
  15. # than the MEM_LIMIT
  16. _MEM_LIMIT = sys.maxsize // 2
  17. class Scrypt(KeyDerivationFunction):
  18. def __init__(
  19. self,
  20. salt: bytes,
  21. length: int,
  22. n: int,
  23. r: int,
  24. p: int,
  25. backend: typing.Any = None,
  26. ):
  27. from cryptography.hazmat.backends.openssl.backend import (
  28. backend as ossl,
  29. )
  30. if not ossl.scrypt_supported():
  31. raise UnsupportedAlgorithm(
  32. "This version of OpenSSL does not support scrypt"
  33. )
  34. self._length = length
  35. utils._check_bytes("salt", salt)
  36. if n < 2 or (n & (n - 1)) != 0:
  37. raise ValueError("n must be greater than 1 and be a power of 2.")
  38. if r < 1:
  39. raise ValueError("r must be greater than or equal to 1.")
  40. if p < 1:
  41. raise ValueError("p must be greater than or equal to 1.")
  42. self._used = False
  43. self._salt = salt
  44. self._n = n
  45. self._r = r
  46. self._p = p
  47. def derive(self, key_material: bytes) -> bytes:
  48. if self._used:
  49. raise AlreadyFinalized("Scrypt instances can only be used once.")
  50. self._used = True
  51. utils._check_byteslike("key_material", key_material)
  52. from cryptography.hazmat.backends.openssl.backend import backend
  53. return backend.derive_scrypt(
  54. key_material, self._salt, self._length, self._n, self._r, self._p
  55. )
  56. def verify(self, key_material: bytes, expected_key: bytes) -> None:
  57. derived_key = self.derive(key_material)
  58. if not constant_time.bytes_eq(derived_key, expected_key):
  59. raise InvalidKey("Keys do not match.")