MongoDB migration

This commit is contained in:
2023-01-18 14:25:22 +01:00
parent 94f8839e53
commit e9e9e3784a
19 changed files with 462 additions and 138 deletions

View File

@@ -1,11 +1,14 @@
from os import path
from typing import Union
from fastapi import FastAPI, Security, HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from fastapi.security import APIKeyQuery, APIKeyHeader, APIKeyCookie
from fastapi.openapi.models import APIKey
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from starlette.status import HTTP_401_UNAUTHORIZED
from modules.utils import configGet, jsonLoad
from modules.security import passEncode
from modules.database import col_apikeys, col_expired
app = FastAPI(title="Stardew Sync", docs_url=None, redoc_url=None, version="0.1")
@@ -40,11 +43,8 @@ async def get_api_key(
api_key_cookie: str = Security(api_key_cookie),
) -> str:
keys = get_all_api_keys()
expired = get_all_expired_keys()
def is_valid(key):
return True if key in keys else False
return True if col_apikeys.find_one({"hash": passEncode(key)}) is not None else False
if is_valid(api_key_query):
return api_key_query
@@ -53,11 +53,15 @@ async def get_api_key(
elif is_valid(api_key_cookie):
return api_key_cookie
else:
if (api_key_query in expired) or (api_key_header in expired) or (api_key_cookie in expired):
if (col_expired.find_one({"hash": passEncode(api_key_query)}) is not None) or (col_expired.find_one({"hash": passEncode(api_key_header)}) is not None) or (col_expired.find_one({"hash": passEncode(api_key_cookie)}) is not None):
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=configGet("key_expired", "messages"))
else:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=configGet("key_invalid", "messages"))
def user_by_key(apikey: Union[str, APIKey]) -> Union[str, None]:
db_key = col_apikeys.find_one({"hash": passEncode(apikey)})
return db_key["user"] if db_key is not None else None
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
return get_swagger_ui_html(

34
modules/database.py Normal file
View File

@@ -0,0 +1,34 @@
from modules.utils import configGet
from pymongo import MongoClient
db_config = configGet("database")
if db_config["user"] is not None and db_config["password"] is not None:
con_string = 'mongodb://{0}:{1}@{2}:{3}/{4}'.format(
db_config["user"],
db_config["password"],
db_config["host"],
db_config["port"],
db_config["name"]
)
else:
con_string = 'mongodb://{0}:{1}/{2}'.format(
db_config["host"],
db_config["port"],
db_config["name"]
)
db_client = MongoClient(con_string)
db = db_client.get_database(name=db_config["name"])
collections = db.list_collection_names()
for collection in ["saves", "devices", "apikeys", "expired"]:
if not collection in collections:
db.create_collection(collection)
col_saves = db.get_collection("saves")
col_devices = db.get_collection("devices")
col_apikeys = db.get_collection("apikeys")
col_expired = db.get_collection("expired")

17
modules/security.py Normal file
View File

@@ -0,0 +1,17 @@
from hashlib import pbkdf2_hmac
from os import chmod, path, urandom
from typing import Union
from modules.utils import configGet
from fastapi.openapi.models import APIKey
def saltRead():
if not path.exists(path.join(configGet("data", "locations"), "salt")):
with open(path.join(configGet("data", "locations"), "salt"), "wb") as file:
file.write(urandom(32))
chmod(path.join(configGet("data", "locations"), "salt"), mode=0o600)
with open(path.join(configGet("data", "locations"), "salt"), "rb") as file:
contents = file.read()
return contents
def passEncode(password: Union[str, APIKey, None]) -> Union[bytes, None]:
return None if password is None else pbkdf2_hmac("sha256", str(password).encode("utf-8"), saltRead(), 96800, dklen=128)

View File

@@ -1,4 +1,6 @@
from typing import Any, Union
from os import makedirs, path
from typing import Any, Tuple, Union
from uuid import uuid4
from ujson import loads, dumps, JSONDecodeError
from traceback import print_exc
@@ -58,4 +60,19 @@ def configGet(key: str, *args: str) -> Any:
this_key = this_dict
for dict_key in args:
this_key = this_key[dict_key]
return this_key[key]
return this_key[key]
def saveFile(filebytes: bytes) -> Tuple[str, str]:
"""Save some bytedata into random file and return its ID
### Args:
* filebytes (`bytes`): Bytes to write into file
### Returns:
* `Tuple[str, str]`: Tuple where first item is an ID and the second is an absolute path to file
"""
makedirs(path.join(configGet("data", "locations"), "files"), exist_ok=True)
filename = str(uuid4())
with open(path.join(configGet("data", "locations"), "files", filename), "wb") as file:
file.write(filebytes)
return filename, path.join(configGet("data", "locations"), "files", filename)