71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
import logging
import traceback
from urllib.parse import parse_qs

from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.db import close_old_connections
from jwt import DecodeError, ExpiredSignatureError, InvalidSignatureError
from jwt import decode as jwt_decode
|
|
|
# The active user model (the one named by settings.AUTH_USER_MODEL), resolved
# once when this module is imported.
# NOTE(review): this import-time call (like the AnonymousUser import above)
# presumably requires Django's app registry to be loaded before this module is
# imported, e.g. after get_asgi_application() in asgi.py — confirm.
User = get_user_model()
|
|
|
|
|
|
logger = logging.getLogger(__name__)


class JWTAuthMiddleware:
    """ASGI middleware that sets ``scope["user"]`` from a JWT in the query string.

    The token is taken from the first ``token`` query-string parameter
    (e.g. ``?token=<jwt>``), verified with ``settings.SECRET_KEY`` using
    HS256, and the credential extracted from its payload (the ``user_id``
    claim by default) is used to load the user.

    ``scope["user"]`` is always populated before the inner app runs: it is
    ``AnonymousUser()`` when no token is supplied, the token fails
    verification, the claim is missing, no matching user exists, or any other
    error occurs while authenticating.
    """

    def __init__(self, app):
        # Inner ASGI application invoked once scope["user"] has been set.
        self.app = app

    async def __call__(self, scope, receive, send):
        """Authenticate the connection, then delegate to the inner app.

        Errors raised while authenticating are logged and downgraded to an
        anonymous user; errors raised by the inner app propagate unchanged.
        """
        # No close_old_connections() here: get_user runs through
        # database_sync_to_async, which already calls it around the query,
        # and it is a synchronous DB helper that must not run on the event loop.
        try:
            # ASGI lists query_string as optional with a default of b"".
            query = parse_qs(scope.get("query_string", b"").decode("utf8"))
            if jwt_token_list := query.get("token"):
                jwt_payload = self.get_payload(jwt_token_list[0])
                user_credentials = self.get_user_credentials(jwt_payload)
                user = await self.get_logged_in_user(user_credentials)
            else:
                user = AnonymousUser()
        except (
            InvalidSignatureError,
            KeyError,
            ExpiredSignatureError,
            DecodeError,
        ) as exc:
            # Expected client-side failures (bad/expired token, missing claim).
            # Fall back to anonymous like every other failure path, so a
            # rejected token never leaves scope["user"] unset. The token
            # itself is deliberately not logged.
            logger.warning("JWT authentication failed: %r", exc)
            user = AnonymousUser()
        except Exception:
            logger.exception("Unexpected error during JWT authentication")
            user = AnonymousUser()

        scope["user"] = user
        return await self.app(scope, receive, send)

    def get_payload(self, jwt_token):
        """Decode and verify *jwt_token*, returning its payload.

        The signature is checked against ``settings.SECRET_KEY`` and only
        HS256 is accepted. PyJWT raises ``DecodeError``,
        ``InvalidSignatureError``, ``ExpiredSignatureError`` (among others)
        for tokens that fail validation.
        """
        return jwt_decode(jwt_token, settings.SECRET_KEY, algorithms=["HS256"])

    def get_user_credentials(self, payload):
        """Return the credential used to look up the user from *payload*.

        Defaults to the ``user_id`` claim and raises ``KeyError`` when it is
        absent; override to authenticate from a different claim.
        """
        return payload["user_id"]

    async def get_logged_in_user(self, user_id):
        """Return the user for *user_id*, or ``AnonymousUser`` if none exists."""
        return await self.get_user(user_id)

    @database_sync_to_async
    def get_user(self, user_id):
        """Load the user whose ``id`` equals *user_id*.

        Runs the ORM query through ``database_sync_to_async``. Returns
        ``AnonymousUser()`` when no such user exists; other errors (e.g. a
        malformed id) propagate to the caller.
        """
        try:
            return User.objects.get(id=user_id)
        except User.DoesNotExist:
            return AnonymousUser()
|
|
|
|
|
|
def JWTAuthMiddlewareStack(app):
    """Build the middleware stack used to serve *app*.

    *app* is first wrapped in Channels' standard ``AuthMiddlewareStack``;
    ``JWTAuthMiddleware`` is then placed outermost, so it runs first on each
    connection and populates ``scope["user"]`` before the inner stack.
    """
    inner_stack = AuthMiddlewareStack(app)
    return JWTAuthMiddleware(inner_stack)
|