aead.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. import typing
  5. from cryptography.exceptions import InvalidTag
if typing.TYPE_CHECKING:
    # Imported only for static type checking. At runtime the functions in
    # this module import the AEAD classes locally instead (presumably to
    # avoid an import cycle with the primitives package — confirm).
    from cryptography.hazmat.backends.openssl.backend import Backend
    from cryptography.hazmat.primitives.ciphers.aead import (
        AESCCM,
        AESGCM,
        AESOCB3,
        AESSIV,
        ChaCha20Poly1305,
    )

    # Union of every AEAD primitive the helpers in this module accept; used
    # only in (string) annotations.
    _AEAD_TYPES = typing.Union[
        AESCCM, AESGCM, AESOCB3, AESSIV, ChaCha20Poly1305
    ]

# Operation selectors for _aead_setup. The ``enc`` flag handed to
# EVP_CipherInit_ex is derived as ``int(operation == _ENCRYPT)``.
_ENCRYPT = 1
_DECRYPT = 0
  20. def _aead_cipher_name(cipher: "_AEAD_TYPES") -> bytes:
  21. from cryptography.hazmat.primitives.ciphers.aead import (
  22. AESCCM,
  23. AESGCM,
  24. AESOCB3,
  25. AESSIV,
  26. ChaCha20Poly1305,
  27. )
  28. if isinstance(cipher, ChaCha20Poly1305):
  29. return b"chacha20-poly1305"
  30. elif isinstance(cipher, AESCCM):
  31. return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
  32. elif isinstance(cipher, AESOCB3):
  33. return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii")
  34. elif isinstance(cipher, AESSIV):
  35. return f"aes-{len(cipher._key) * 8 // 2}-siv".encode("ascii")
  36. else:
  37. assert isinstance(cipher, AESGCM)
  38. return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii")
  39. def _evp_cipher(cipher_name: bytes, backend: "Backend"):
  40. if cipher_name.endswith(b"-siv"):
  41. evp_cipher = backend._lib.EVP_CIPHER_fetch(
  42. backend._ffi.NULL,
  43. cipher_name,
  44. backend._ffi.NULL,
  45. )
  46. backend.openssl_assert(evp_cipher != backend._ffi.NULL)
  47. evp_cipher = backend._ffi.gc(evp_cipher, backend._lib.EVP_CIPHER_free)
  48. else:
  49. evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
  50. backend.openssl_assert(evp_cipher != backend._ffi.NULL)
  51. return evp_cipher
  52. def _aead_setup(
  53. backend: "Backend",
  54. cipher_name: bytes,
  55. key: bytes,
  56. nonce: bytes,
  57. tag: typing.Optional[bytes],
  58. tag_len: int,
  59. operation: int,
  60. ):
  61. evp_cipher = _evp_cipher(cipher_name, backend)
  62. ctx = backend._lib.EVP_CIPHER_CTX_new()
  63. ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
  64. res = backend._lib.EVP_CipherInit_ex(
  65. ctx,
  66. evp_cipher,
  67. backend._ffi.NULL,
  68. backend._ffi.NULL,
  69. backend._ffi.NULL,
  70. int(operation == _ENCRYPT),
  71. )
  72. backend.openssl_assert(res != 0)
  73. res = backend._lib.EVP_CIPHER_CTX_set_key_length(ctx, len(key))
  74. backend.openssl_assert(res != 0)
  75. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  76. ctx,
  77. backend._lib.EVP_CTRL_AEAD_SET_IVLEN,
  78. len(nonce),
  79. backend._ffi.NULL,
  80. )
  81. backend.openssl_assert(res != 0)
  82. if operation == _DECRYPT:
  83. assert tag is not None
  84. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  85. ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
  86. )
  87. backend.openssl_assert(res != 0)
  88. elif cipher_name.endswith(b"-ccm"):
  89. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  90. ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, tag_len, backend._ffi.NULL
  91. )
  92. backend.openssl_assert(res != 0)
  93. nonce_ptr = backend._ffi.from_buffer(nonce)
  94. key_ptr = backend._ffi.from_buffer(key)
  95. res = backend._lib.EVP_CipherInit_ex(
  96. ctx,
  97. backend._ffi.NULL,
  98. backend._ffi.NULL,
  99. key_ptr,
  100. nonce_ptr,
  101. int(operation == _ENCRYPT),
  102. )
  103. backend.openssl_assert(res != 0)
  104. return ctx
  105. def _set_length(backend: "Backend", ctx, data_len: int) -> None:
  106. intptr = backend._ffi.new("int *")
  107. res = backend._lib.EVP_CipherUpdate(
  108. ctx, backend._ffi.NULL, intptr, backend._ffi.NULL, data_len
  109. )
  110. backend.openssl_assert(res != 0)
  111. def _process_aad(backend: "Backend", ctx, associated_data: bytes) -> None:
  112. outlen = backend._ffi.new("int *")
  113. res = backend._lib.EVP_CipherUpdate(
  114. ctx, backend._ffi.NULL, outlen, associated_data, len(associated_data)
  115. )
  116. backend.openssl_assert(res != 0)
  117. def _process_data(backend: "Backend", ctx, data: bytes) -> bytes:
  118. outlen = backend._ffi.new("int *")
  119. buf = backend._ffi.new("unsigned char[]", len(data))
  120. res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
  121. if res == 0:
  122. # AES SIV can error here if the data is invalid on decrypt
  123. backend._consume_errors()
  124. raise InvalidTag
  125. return backend._ffi.buffer(buf, outlen[0])[:]
  126. def _encrypt(
  127. backend: "Backend",
  128. cipher: "_AEAD_TYPES",
  129. nonce: bytes,
  130. data: bytes,
  131. associated_data: typing.List[bytes],
  132. tag_length: int,
  133. ) -> bytes:
  134. from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
  135. cipher_name = _aead_cipher_name(cipher)
  136. ctx = _aead_setup(
  137. backend, cipher_name, cipher._key, nonce, None, tag_length, _ENCRYPT
  138. )
  139. # CCM requires us to pass the length of the data before processing anything
  140. # However calling this with any other AEAD results in an error
  141. if isinstance(cipher, AESCCM):
  142. _set_length(backend, ctx, len(data))
  143. for ad in associated_data:
  144. _process_aad(backend, ctx, ad)
  145. processed_data = _process_data(backend, ctx, data)
  146. outlen = backend._ffi.new("int *")
  147. # All AEADs we support besides OCB are streaming so they return nothing
  148. # in finalization. OCB can return up to (16 byte block - 1) bytes so
  149. # we need a buffer here too.
  150. buf = backend._ffi.new("unsigned char[]", 16)
  151. res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen)
  152. backend.openssl_assert(res != 0)
  153. processed_data += backend._ffi.buffer(buf, outlen[0])[:]
  154. tag_buf = backend._ffi.new("unsigned char[]", tag_length)
  155. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  156. ctx, backend._lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf
  157. )
  158. backend.openssl_assert(res != 0)
  159. tag = backend._ffi.buffer(tag_buf)[:]
  160. if isinstance(cipher, AESSIV):
  161. # RFC 5297 defines the output as IV || C, where the tag we generate is
  162. # the "IV" and C is the ciphertext. This is the opposite of our
  163. # other AEADs, which are Ciphertext || Tag
  164. backend.openssl_assert(len(tag) == 16)
  165. return tag + processed_data
  166. else:
  167. return processed_data + tag
  168. def _decrypt(
  169. backend: "Backend",
  170. cipher: "_AEAD_TYPES",
  171. nonce: bytes,
  172. data: bytes,
  173. associated_data: typing.List[bytes],
  174. tag_length: int,
  175. ) -> bytes:
  176. from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
  177. if len(data) < tag_length:
  178. raise InvalidTag
  179. if isinstance(cipher, AESSIV):
  180. # RFC 5297 defines the output as IV || C, where the tag we generate is
  181. # the "IV" and C is the ciphertext. This is the opposite of our
  182. # other AEADs, which are Ciphertext || Tag
  183. tag = data[:tag_length]
  184. data = data[tag_length:]
  185. else:
  186. tag = data[-tag_length:]
  187. data = data[:-tag_length]
  188. cipher_name = _aead_cipher_name(cipher)
  189. ctx = _aead_setup(
  190. backend, cipher_name, cipher._key, nonce, tag, tag_length, _DECRYPT
  191. )
  192. # CCM requires us to pass the length of the data before processing anything
  193. # However calling this with any other AEAD results in an error
  194. if isinstance(cipher, AESCCM):
  195. _set_length(backend, ctx, len(data))
  196. for ad in associated_data:
  197. _process_aad(backend, ctx, ad)
  198. # CCM has a different error path if the tag doesn't match. Errors are
  199. # raised in Update and Final is irrelevant.
  200. if isinstance(cipher, AESCCM):
  201. outlen = backend._ffi.new("int *")
  202. buf = backend._ffi.new("unsigned char[]", len(data))
  203. res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
  204. if res != 1:
  205. backend._consume_errors()
  206. raise InvalidTag
  207. processed_data = backend._ffi.buffer(buf, outlen[0])[:]
  208. else:
  209. processed_data = _process_data(backend, ctx, data)
  210. outlen = backend._ffi.new("int *")
  211. # OCB can return up to 15 bytes (16 byte block - 1) in finalization
  212. buf = backend._ffi.new("unsigned char[]", 16)
  213. res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen)
  214. processed_data += backend._ffi.buffer(buf, outlen[0])[:]
  215. if res == 0:
  216. backend._consume_errors()
  217. raise InvalidTag
  218. return processed_data