From f34df7d5b1f477411480f82282a9c413dd2296e0 Mon Sep 17 00:00:00 2001 From: Profitroll <47523801+profitrollgame@users.noreply.github.com> Date: Tue, 20 Dec 2022 11:36:54 +0100 Subject: [PATCH] Started creating auth system --- extensions/security.py | 56 ++++++++++++++++ modules/security.py | 143 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+) create mode 100644 extensions/security.py create mode 100644 modules/security.py diff --git a/extensions/security.py b/extensions/security.py new file mode 100644 index 0000000..10a0ef2 --- /dev/null +++ b/extensions/security.py @@ -0,0 +1,56 @@ +from datetime import timedelta +from modules.database import col_users +from modules.app import app + +from fastapi import Depends, HTTPException, Security, Response +from starlette.status import HTTP_204_NO_CONTENT +from fastapi.security import ( + OAuth2PasswordRequestForm, +) + +from modules.security import ( + ACCESS_TOKEN_EXPIRE_DAYS, + Token, + User, + authenticate_user, + create_access_token, + get_current_active_user, + get_current_user, + get_password_hash +) + + +@app.post("/token", response_model=Token) +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + user = authenticate_user(form_data.username, form_data.password) + if not user: + raise HTTPException(status_code=400, detail="Incorrect user or password") + access_token_expires = timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS) + access_token = create_access_token( + data={"sub": user.user, "scopes": form_data.scopes}, + expires_delta=access_token_expires, + ) + return {"access_token": access_token, "token_type": "bearer"} + + +@app.get("/users/me/", response_model=User) +async def read_users_me(current_user: User = Depends(get_current_active_user)): + return current_user + + +@app.post("/users", response_class=Response) +async def create_users(user: str, email: str, password: str): + col_users.insert_one( {"user": user, "email": email, "hash": get_password_hash(password), "disabled": True} ) + return Response(status_code=HTTP_204_NO_CONTENT) + + +@app.get("/users/me/items/") +async def read_own_items( + current_user: User = Security(get_current_active_user, scopes=["items"]) +): + return [{"item_id": "Foo", "owner": current_user.user}] + + +@app.get("/status/") +async def read_system_status(current_user: User = Depends(get_current_user)): + return {"status": "ok"} \ No newline at end of file diff --git a/modules/security.py b/modules/security.py new file mode 100644 index 0000000..0511120 --- /dev/null +++ b/modules/security.py @@ -0,0 +1,143 @@ +from datetime import datetime, timedelta +from typing import List, Union +from modules.database import col_users +from modules.app import app + +from fastapi import Depends, HTTPException, Security, status +from starlette.status import HTTP_204_NO_CONTENT +from fastapi.security import ( + OAuth2PasswordBearer, + SecurityScopes, +) +from jose import JWTError, jwt +from passlib.context import CryptContext +from pydantic import BaseModel, ValidationError + + +with open("secret_key", "r", encoding="utf-8") as f: + SECRET_KEY = f.read() +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_DAYS = 180 + + +fake_users_db = { + "johndoe": { + "user": "johndoe", + "email": "johndoe@example.com", + "hash": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", + "disabled": False, + }, + "alice": { + "user": "alice", + "email": "alicechains@example.com", + "hash": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", + "disabled": True, + }, +} + + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + user: Union[str, None] = None + scopes: List[str] = [] + + +class User(BaseModel): + user: str + email: Union[str, None] = None + disabled: Union[bool, None] = None + + +class UserInDB(User): + hash: str + + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl="token", + scopes={ + "me": "Get current user's data.", + "list": "List albums and images.", + "read": "View albums and images.", + "write": "Manage albums and images."}, +) + + +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def get_user(user: str): + found_user = col_users.find_one( {"user": user} ) + return UserInDB(user=found_user["user"], email=found_user["email"], disabled=found_user["disabled"], hash=found_user["hash"]) + + +def authenticate_user(user_name: str, password: str): + user = get_user(user_name) + if not user: + return False + if not verify_password(password, user.hash): + return False + return user + + +def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(days=ACCESS_TOKEN_EXPIRE_DAYS) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +async def get_current_user( + security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme) +): + if security_scopes.scopes: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + else: + authenticate_value = "Bearer" + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": authenticate_value}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + user: str = payload.get("sub") + if user is None: + raise credentials_exception + token_scopes = payload.get("scopes", []) + token_data = TokenData(scopes=token_scopes, user=user) + except (JWTError, ValidationError): + raise credentials_exception + user = get_user(user=token_data.user) + if user is None: + raise credentials_exception + for scope in security_scopes.scopes: + if scope not in token_data.scopes: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={"WWW-Authenticate": authenticate_value}, + ) + return user + + +async def get_current_active_user( + current_user: User = Security(get_current_user, scopes=["me"]) +): + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user \ No newline at end of file