serializers.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from django.contrib.auth import authenticate, get_user_model
  2. from django.contrib.auth.models import update_last_login
  3. from django.utils.translation import gettext_lazy as _
  4. from rest_framework import exceptions, serializers
  5. from rest_framework.exceptions import ValidationError
  6. from .settings import api_settings
  7. from .tokens import RefreshToken, SlidingToken, UntypedToken
  8. if api_settings.BLACKLIST_AFTER_ROTATION:
  9. from .token_blacklist.models import BlacklistedToken
  10. class PasswordField(serializers.CharField):
  11. def __init__(self, *args, **kwargs):
  12. kwargs.setdefault('style', {})
  13. kwargs['style']['input_type'] = 'password'
  14. kwargs['write_only'] = True
  15. super().__init__(*args, **kwargs)
  16. class TokenObtainSerializer(serializers.Serializer):
  17. username_field = get_user_model().USERNAME_FIELD
  18. default_error_messages = {
  19. 'no_active_account': _('No active account found with the given credentials')
  20. }
  21. def __init__(self, *args, **kwargs):
  22. super().__init__(*args, **kwargs)
  23. self.fields[self.username_field] = serializers.CharField()
  24. self.fields['password'] = PasswordField()
  25. def validate(self, attrs):
  26. authenticate_kwargs = {
  27. self.username_field: attrs[self.username_field],
  28. 'password': attrs['password'],
  29. }
  30. try:
  31. authenticate_kwargs['request'] = self.context['request']
  32. except KeyError:
  33. pass
  34. self.user = authenticate(**authenticate_kwargs)
  35. if not api_settings.USER_AUTHENTICATION_RULE(self.user):
  36. raise exceptions.AuthenticationFailed(
  37. self.error_messages['no_active_account'],
  38. 'no_active_account',
  39. )
  40. return {}
  41. @classmethod
  42. def get_token(cls, user):
  43. raise NotImplementedError('Must implement `get_token` method for `TokenObtainSerializer` subclasses')
  44. class TokenObtainPairSerializer(TokenObtainSerializer):
  45. @classmethod
  46. def get_token(cls, user):
  47. return RefreshToken.for_user(user)
  48. def validate(self, attrs):
  49. data = super().validate(attrs)
  50. refresh = self.get_token(self.user)
  51. data['refresh'] = str(refresh)
  52. data['access'] = str(refresh.access_token)
  53. if api_settings.UPDATE_LAST_LOGIN:
  54. update_last_login(None, self.user)
  55. return data
  56. class TokenObtainSlidingSerializer(TokenObtainSerializer):
  57. @classmethod
  58. def get_token(cls, user):
  59. return SlidingToken.for_user(user)
  60. def validate(self, attrs):
  61. data = super().validate(attrs)
  62. token = self.get_token(self.user)
  63. data['token'] = str(token)
  64. if api_settings.UPDATE_LAST_LOGIN:
  65. update_last_login(None, self.user)
  66. return data
  67. class TokenRefreshSerializer(serializers.Serializer):
  68. refresh = serializers.CharField()
  69. access = serializers.ReadOnlyField()
  70. def validate(self, attrs):
  71. refresh = RefreshToken(attrs['refresh'])
  72. data = {'access': str(refresh.access_token)}
  73. if api_settings.ROTATE_REFRESH_TOKENS:
  74. if api_settings.BLACKLIST_AFTER_ROTATION:
  75. try:
  76. # Attempt to blacklist the given refresh token
  77. refresh.blacklist()
  78. except AttributeError:
  79. # If blacklist app not installed, `blacklist` method will
  80. # not be present
  81. pass
  82. refresh.set_jti()
  83. refresh.set_exp()
  84. data['refresh'] = str(refresh)
  85. return data
  86. class TokenRefreshSlidingSerializer(serializers.Serializer):
  87. token = serializers.CharField()
  88. def validate(self, attrs):
  89. token = SlidingToken(attrs['token'])
  90. # Check that the timestamp in the "refresh_exp" claim has not
  91. # passed
  92. token.check_exp(api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM)
  93. # Update the "exp" claim
  94. token.set_exp()
  95. return {'token': str(token)}
  96. class TokenVerifySerializer(serializers.Serializer):
  97. token = serializers.CharField()
  98. def validate(self, attrs):
  99. token = UntypedToken(attrs['token'])
  100. if api_settings.BLACKLIST_AFTER_ROTATION:
  101. jti = token.get(api_settings.JTI_CLAIM)
  102. if BlacklistedToken.objects.filter(token__jti=jti).exists():
  103. raise ValidationError("Token is blacklisted")
  104. return {}