MongoDB migration
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from os import path
|
||||
from typing import Union
|
||||
from fastapi import FastAPI, Security, HTTPException
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from fastapi.security import APIKeyQuery, APIKeyHeader, APIKeyCookie
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from modules.utils import configGet, jsonLoad
|
||||
from modules.security import passEncode
|
||||
from modules.database import col_apikeys, col_expired
|
||||
|
||||
app = FastAPI(title="Stardew Sync", docs_url=None, redoc_url=None, version="0.1")
|
||||
|
||||
@@ -40,11 +43,8 @@ async def get_api_key(
|
||||
api_key_cookie: str = Security(api_key_cookie),
|
||||
) -> str:
|
||||
|
||||
keys = get_all_api_keys()
|
||||
expired = get_all_expired_keys()
|
||||
|
||||
def is_valid(key):
|
||||
return True if key in keys else False
|
||||
return True if col_apikeys.find_one({"hash": passEncode(key)}) is not None else False
|
||||
|
||||
if is_valid(api_key_query):
|
||||
return api_key_query
|
||||
@@ -53,11 +53,15 @@ async def get_api_key(
|
||||
elif is_valid(api_key_cookie):
|
||||
return api_key_cookie
|
||||
else:
|
||||
if (api_key_query in expired) or (api_key_header in expired) or (api_key_cookie in expired):
|
||||
if (col_expired.find_one({"hash": passEncode(api_key_query)}) is not None) or (col_expired.find_one({"hash": passEncode(api_key_header)}) is not None) or (col_expired.find_one({"hash": passEncode(api_key_cookie)}) is not None):
|
||||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail=configGet("key_expired", "messages"))
|
||||
else:
|
||||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=configGet("key_invalid", "messages"))
|
||||
|
||||
def user_by_key(apikey: Union[str, APIKey]) -> Union[str, None]:
|
||||
db_key = col_apikeys.find_one({"hash": passEncode(apikey)})
|
||||
return db_key["user"] if db_key is not None else None
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
async def custom_swagger_ui_html():
|
||||
return get_swagger_ui_html(
|
||||
|
34
modules/database.py
Normal file
34
modules/database.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from modules.utils import configGet
|
||||
from pymongo import MongoClient
|
||||
|
||||
db_config = configGet("database")
|
||||
|
||||
if db_config["user"] is not None and db_config["password"] is not None:
|
||||
con_string = 'mongodb://{0}:{1}@{2}:{3}/{4}'.format(
|
||||
db_config["user"],
|
||||
db_config["password"],
|
||||
db_config["host"],
|
||||
db_config["port"],
|
||||
db_config["name"]
|
||||
)
|
||||
else:
|
||||
con_string = 'mongodb://{0}:{1}/{2}'.format(
|
||||
db_config["host"],
|
||||
db_config["port"],
|
||||
db_config["name"]
|
||||
)
|
||||
|
||||
db_client = MongoClient(con_string)
|
||||
|
||||
db = db_client.get_database(name=db_config["name"])
|
||||
|
||||
collections = db.list_collection_names()
|
||||
|
||||
for collection in ["saves", "devices", "apikeys", "expired"]:
|
||||
if not collection in collections:
|
||||
db.create_collection(collection)
|
||||
|
||||
col_saves = db.get_collection("saves")
|
||||
col_devices = db.get_collection("devices")
|
||||
col_apikeys = db.get_collection("apikeys")
|
||||
col_expired = db.get_collection("expired")
|
17
modules/security.py
Normal file
17
modules/security.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from hashlib import pbkdf2_hmac
|
||||
from os import chmod, path, urandom
|
||||
from typing import Union
|
||||
from modules.utils import configGet
|
||||
from fastapi.openapi.models import APIKey
|
||||
|
||||
def saltRead():
|
||||
if not path.exists(path.join(configGet("data", "locations"), "salt")):
|
||||
with open(path.join(configGet("data", "locations"), "salt"), "wb") as file:
|
||||
file.write(urandom(32))
|
||||
chmod(path.join(configGet("data", "locations"), "salt"), mode=0o600)
|
||||
with open(path.join(configGet("data", "locations"), "salt"), "rb") as file:
|
||||
contents = file.read()
|
||||
return contents
|
||||
|
||||
def passEncode(password: Union[str, APIKey, None]) -> Union[bytes, None]:
|
||||
return None if password is None else pbkdf2_hmac("sha256", str(password).encode("utf-8"), saltRead(), 96800, dklen=128)
|
@@ -1,4 +1,6 @@
|
||||
from typing import Any, Union
|
||||
from os import makedirs, path
|
||||
from typing import Any, Tuple, Union
|
||||
from uuid import uuid4
|
||||
from ujson import loads, dumps, JSONDecodeError
|
||||
from traceback import print_exc
|
||||
|
||||
@@ -58,4 +60,19 @@ def configGet(key: str, *args: str) -> Any:
|
||||
this_key = this_dict
|
||||
for dict_key in args:
|
||||
this_key = this_key[dict_key]
|
||||
return this_key[key]
|
||||
return this_key[key]
|
||||
|
||||
def saveFile(filebytes: bytes) -> Tuple[str, str]:
|
||||
"""Save some bytedata into random file and return its ID
|
||||
|
||||
### Args:
|
||||
* filebytes (`bytes`): Bytes to write into file
|
||||
|
||||
### Returns:
|
||||
* `Tuple[str, str]`: Tuple where first item is an ID and the second is an absolute path to file
|
||||
"""
|
||||
makedirs(path.join(configGet("data", "locations"), "files"), exist_ok=True)
|
||||
filename = str(uuid4())
|
||||
with open(path.join(configGet("data", "locations"), "files", filename), "wb") as file:
|
||||
file.write(filebytes)
|
||||
return filename, path.join(configGet("data", "locations"), "files", filename)
|
Reference in New Issue
Block a user