api_jwk.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import json
  2. from .algorithms import get_default_algorithms
  3. from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
  class PyJWK:
      """A single JSON Web Key parsed into a usable key object.

      The algorithm name is resolved in this order: the explicit ``algorithm``
      argument, then the JWK's ``"alg"`` member (only when the JWK is a real
      ``dict``), and finally a default inferred from ``"kty"`` and ``"crv"``.
      The matching entry from ``get_default_algorithms()`` then builds
      ``self.key`` via its ``from_jwk`` method.
      """

      def __init__(self, jwk_data, algorithm=None):
          """Parse ``jwk_data`` and resolve its algorithm and key.

          Args:
              jwk_data: JWK members; must support ``.get`` (normally a dict).
              algorithm: Optional algorithm name that overrides ``"alg"`` and
                  the kty/crv-based inference.

          Raises:
              InvalidKeyError: ``"kty"`` is missing/empty, or (when inferring)
                  the ``kty``/``crv`` combination is not supported.
              PyJWKError: No registered algorithm matches the resolved name.
          """
          self._algorithms = get_default_algorithms()
          self._jwk_data = jwk_data

          key_type = self._jwk_data.get("kty", None)
          if not key_type:
              raise InvalidKeyError(f"kty is not found: {self._jwk_data}")

          # The "alg" member is only consulted for plain dicts.
          if not algorithm and isinstance(self._jwk_data, dict):
              algorithm = self._jwk_data.get("alg", None)

          if not algorithm:
              algorithm = self._infer_algorithm(key_type)

          self.Algorithm = self._algorithms.get(algorithm)
          if not self.Algorithm:
              raise PyJWKError(f"Unable to find a algorithm for key: {self._jwk_data}")

          self.key = self.Algorithm.from_jwk(self._jwk_data)

      def _infer_algorithm(self, kty):
          """Return a default algorithm name for ``kty`` (and the JWK's ``"crv"``).

          Raises:
              InvalidKeyError: Unsupported ``kty``, unsupported ``crv``, or an
                  ``OKP`` key without a ``crv``.
          """
          crv = self._jwk_data.get("crv", None)

          if kty == "EC":
              # A missing/empty curve defaults to P-256.
              if crv == "P-256" or not crv:
                  return "ES256"
              # Ordered pairs + ``==`` (not a dict lookup) so that any crv
              # value, hashable or not, ends in the same error as before.
              other_curves = (
                  ("P-384", "ES384"),
                  ("P-521", "ES512"),
                  ("secp256k1", "ES256K"),
              )
              for curve_name, alg_name in other_curves:
                  if crv == curve_name:
                      return alg_name
              raise InvalidKeyError(f"Unsupported crv: {crv}")

          if kty == "RSA":
              return "RS256"

          if kty == "oct":
              return "HS256"

          if kty == "OKP":
              if not crv:
                  raise InvalidKeyError(f"crv is not found: {self._jwk_data}")
              if crv == "Ed25519":
                  return "EdDSA"
              raise InvalidKeyError(f"Unsupported crv: {crv}")

          raise InvalidKeyError(f"Unsupported kty: {kty}")

      @staticmethod
      def from_dict(obj, algorithm=None):
          """Build a ``PyJWK`` from an already-decoded JWK mapping."""
          return PyJWK(obj, algorithm)

      @staticmethod
      def from_json(data, algorithm=None):
          """Build a ``PyJWK`` from a JSON-encoded JWK string."""
          decoded = json.loads(data)
          return PyJWK.from_dict(decoded, algorithm)

      @property
      def key_type(self):
          """The JWK ``"kty"`` member, or ``None`` when absent."""
          return self._jwk_data.get("kty", None)

      @property
      def key_id(self):
          """The JWK ``"kid"`` member, or ``None`` when absent."""
          return self._jwk_data.get("kid", None)

      @property
      def public_key_use(self):
          """The JWK ``"use"`` member, or ``None`` when absent."""
          return self._jwk_data.get("use", None)
  class PyJWKSet:
      """A JSON Web Key Set: an ordered list of ``PyJWK`` objects."""

      def __init__(self, keys):
          """Parse every JWK in ``keys`` into a ``PyJWK``.

          Args:
              keys: The JWKS ``"keys"`` member, expected to be a list of JWK
                  mappings.

          Raises:
              PyJWKSetError: ``keys`` is not a list ("Invalid JWK Set value"),
                  or is an empty list ("The JWK Set did not contain any keys").
              Errors raised by ``PyJWK`` for an individual key propagate
              unchanged.
          """
          self.keys = []

          # Check the type first. Previously ``not keys`` ran first and also
          # caught the empty list, which made the "did not contain any keys"
          # branch unreachable.
          if not isinstance(keys, list):
              raise PyJWKSetError("Invalid JWK Set value")

          if not keys:
              raise PyJWKSetError("The JWK Set did not contain any keys")

          self.keys = [PyJWK(key) for key in keys]

      @staticmethod
      def from_dict(obj):
          """Build a ``PyJWKSet`` from a decoded JWKS mapping.

          A missing ``"keys"`` member is treated as an empty list.
          """
          keys = obj.get("keys", [])
          return PyJWKSet(keys)

      @staticmethod
      def from_json(data):
          """Build a ``PyJWKSet`` from a JSON-encoded JWKS string."""
          obj = json.loads(data)
          return PyJWKSet.from_dict(obj)

      def __getitem__(self, kid):
          """Return the first key whose ``key_id`` equals ``kid``.

          Raises:
              KeyError: No key in the set has that ``kid``.
          """
          for key in self.keys:
              if key.key_id == kid:
                  return key
          raise KeyError(f"keyset has no key for kid: {kid}")