123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- from django.contrib.auth import authenticate, get_user_model
- from django.contrib.auth.models import update_last_login
- from django.utils.translation import gettext_lazy as _
- from rest_framework import exceptions, serializers
- from rest_framework.exceptions import ValidationError
- from .settings import api_settings
- from .tokens import RefreshToken, SlidingToken, UntypedToken
- if api_settings.BLACKLIST_AFTER_ROTATION:
- from .token_blacklist.models import BlacklistedToken
class PasswordField(serializers.CharField):
    """Write-only ``CharField`` rendered as a password input.

    Any ``style`` passed by the caller is kept, with ``input_type`` forced to
    ``'password'``. ``write_only`` is always forced to ``True``, whatever the
    caller passed.
    """

    def __init__(self, *args, **kwargs):
        # Build a new style dict instead of mutating the caller's one. A style
        # dict shared between several field declarations would otherwise get
        # ``input_type='password'`` leaked into the other fields. ``or {}``
        # also accepts an explicit ``style=None``, which used to raise
        # TypeError on item assignment.
        kwargs['style'] = {**(kwargs.get('style') or {}), 'input_type': 'password'}
        kwargs['write_only'] = True
        super().__init__(*args, **kwargs)
class TokenObtainSerializer(serializers.Serializer):
    """Base serializer that authenticates a user from submitted credentials.

    The credential fields are the user model's ``USERNAME_FIELD`` and
    ``password``. After a successful ``validate``, ``self.user`` holds the
    value returned by ``authenticate`` and an empty dict is returned.
    Subclasses add token data to that dict and must implement ``get_token``.
    """

    # Looked up once, when this class body is executed.
    username_field = get_user_model().USERNAME_FIELD

    default_error_messages = {
        'no_active_account': _('No active account found with the given credentials')
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The username field's name comes from the user model, so both
        # credential fields are attached per instance instead of being
        # declared as class attributes.
        self.fields[self.username_field] = serializers.CharField()
        self.fields['password'] = PasswordField()

    def validate(self, attrs):
        """Authenticate the credentials in *attrs*.

        Raises ``exceptions.AuthenticationFailed`` with code
        ``'no_active_account'`` when ``api_settings.USER_AUTHENTICATION_RULE``
        returns a falsy value for the authenticated user.
        """
        credentials = {
            self.username_field: attrs[self.username_field],
            'password': attrs['password'],
        }
        # Pass the request on to the auth backends only when the caller put
        # one into the serializer context.
        if 'request' in self.context:
            credentials['request'] = self.context['request']

        self.user = authenticate(**credentials)

        user_is_allowed = api_settings.USER_AUTHENTICATION_RULE(self.user)
        if user_is_allowed:
            return {}

        raise exceptions.AuthenticationFailed(
            self.error_messages['no_active_account'],
            'no_active_account',
        )

    @classmethod
    def get_token(cls, user):
        """Return a token for *user*; concrete subclasses must override this."""
        raise NotImplementedError('Must implement `get_token` method for `TokenObtainSerializer` subclasses')
class TokenObtainPairSerializer(TokenObtainSerializer):
    """Exchange valid credentials for a refresh/access token pair."""

    @classmethod
    def get_token(cls, user):
        """Return a new ``RefreshToken`` issued for *user*."""
        return RefreshToken.for_user(user)

    def validate(self, attrs):
        """Authenticate, then add ``refresh`` and ``access`` token strings."""
        payload = super().validate(attrs)
        refresh_token = self.get_token(self.user)
        # Keyword arguments are evaluated left to right, so ``refresh`` is
        # serialized before the access token is derived from it.
        payload.update(
            refresh=str(refresh_token),
            access=str(refresh_token.access_token),
        )

        # Optionally update the user's ``last_login`` via Django's helper.
        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, self.user)

        return payload
class TokenObtainSlidingSerializer(TokenObtainSerializer):
    """Exchange valid credentials for a single sliding token."""

    @classmethod
    def get_token(cls, user):
        """Return a new ``SlidingToken`` issued for *user*."""
        return SlidingToken.for_user(user)

    def validate(self, attrs):
        """Authenticate, then add the sliding token string under ``token``."""
        payload = super().validate(attrs)
        payload['token'] = str(self.get_token(self.user))

        # Optionally update the user's ``last_login`` via Django's helper.
        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, self.user)

        return payload
class TokenRefreshSerializer(serializers.Serializer):
    """Exchange a refresh token for a new access token.

    When ``ROTATE_REFRESH_TOKENS`` is enabled, the refresh token also gets a
    new ``jti`` and expiry and is returned under ``refresh``. When
    ``BLACKLIST_AFTER_ROTATION`` is also enabled, the submitted token is
    blacklisted first, if the token object supports blacklisting.
    """

    refresh = serializers.CharField()
    access = serializers.ReadOnlyField()

    def validate(self, attrs):
        refresh = RefreshToken(attrs['refresh'])

        # The access token is derived from the submitted refresh token,
        # before any rotation below changes its claims.
        data = {'access': str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                # If the blacklist app is not installed, the `blacklist`
                # method is not present. Only that absence is tolerated. An
                # AttributeError raised *inside* blacklist() is a real failure
                # and must propagate instead of being swallowed.
                blacklist = getattr(refresh, 'blacklist', None)
                if blacklist is not None:
                    blacklist()

            refresh.set_jti()
            refresh.set_exp()

            data['refresh'] = str(refresh)

        return data
class TokenRefreshSlidingSerializer(serializers.Serializer):
    """Slide the expiry of a sliding token that is still within its refresh window."""

    token = serializers.CharField()

    def validate(self, attrs):
        sliding = SlidingToken(attrs['token'])

        # Only allow sliding while the timestamp stored in the refresh-expiry
        # claim (named by SLIDING_TOKEN_REFRESH_EXP_CLAIM) has not passed.
        refresh_claim = api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM
        sliding.check_exp(refresh_claim)

        # Move the regular "exp" claim forward and hand back the re-encoded token.
        sliding.set_exp()
        refreshed = str(sliding)
        return {'token': refreshed}
class TokenVerifySerializer(serializers.Serializer):
    """Verify a token and, when blacklisting is enabled, reject blacklisted ones.

    ``validate`` returns an empty dict on success. Token validity is assumed to
    be enforced by the ``UntypedToken`` constructor, whose implementation is
    not shown here.
    """

    token = serializers.CharField()

    def validate(self, attrs):
        token = UntypedToken(attrs['token'])

        if api_settings.BLACKLIST_AFTER_ROTATION:
            # Import locally instead of relying on the module-level name. That
            # name only exists if the setting was already truthy when this
            # module was imported, so a later change to the setting would
            # otherwise raise NameError here.
            from .token_blacklist.models import BlacklistedToken

            jti = token.get(api_settings.JTI_CLAIM)
            if BlacklistedToken.objects.filter(token__jti=jti).exists():
                # Translatable, like the other user-facing messages in this module.
                raise ValidationError(_("Token is blacklisted"))

        return {}
|