add authentication

add built-in cleanup
general refactor
This commit is contained in:
CEF Server 2024-07-29 03:26:14 +00:00
parent 20ee543ab4
commit ba2e896813
15 changed files with 313 additions and 106 deletions

View file

@ -1,15 +1,46 @@
from fastapi import FastAPI, UploadFile, Request, Depends
import datetime
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import mimetypes
from .auth import JWTBearer
from .sql import SessionMaker, Uploads
from . import util
from .sql import SessionMaker, Uploads, Users
import config
from datetime import datetime, timedelta
from . import endpoints
from .util import minioClient, ergo
from fastapi_utils.tasks import repeat_every
from sqlalchemy import and_
from contextlib import asynccontextmanager
app = FastAPI()
@repeat_every(seconds=60 * 60)
async def cleanup():
sess = SessionMaker()
toCleanup = sess.query(Uploads).filter(Uploads.expiry < datetime.datetime.utcnow()).all()
for upload in toCleanup:
minioClient.remove_object("uploads", upload.hash)
sess.delete(upload)
print(f"Deleted {len(toCleanup)} old files")
tempUsers = sess.query(Users).filter(and_(
Users.temporary == True,
Users.created_at < (datetime.datetime.utcnow() - datetime.timedelta(days=1))
)).all()
for user in tempUsers:
await ergo.fullyRemoveUser(user.username)
sess.delete(user)
print(f"Removed {len(tempUsers)} temp users")
sess.commit()
sess.close()
@asynccontextmanager
async def onStart(app: FastAPI):
await cleanup()
yield
app = FastAPI(lifespan=onStart)
app.add_middleware(
CORSMiddleware,
allow_origins=config.ALLOWED_DOMAINS,
@ -20,26 +51,4 @@ app.add_middleware(
app.include_router(endpoints.router)
@app.post("/upload", dependencies=[Depends(JWTBearer())])
async def upload(file: UploadFile, request: Request):
if file.size > config.MAX_FILE_SIZE:
return {"error": "file too big"}
spl = file.filename.rsplit(".", 1)
safeFilename = util.safeName.sub("_", spl[0])
if len(spl) == 2:
safeFilename += "." + util.safeName.sub("_", spl[1])
sha = await util.SHA256(file)
session = SessionMaker()
if existing := session.query(Uploads).where(Uploads.hash == sha).first():
existing.expiry = datetime.now() + timedelta(days=7)
else:
mime = mimetypes.guess_type(safeFilename)
util.minioClient.put_object("uploads", sha, file.file, file.size, content_type=mime[0])
up = Uploads(hash=sha)
session.add(up)
session.commit()
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/uploads/{sha}/{safeFilename}"}
__all__ = ["sql", "auth", "util"]
__all__ = ["sql", "auth", "util"]

View file

@ -5,6 +5,7 @@ from fastapi.security import HTTPBearer
import config
from fastapi import Request, HTTPException
from .sql import SessionMaker, Users
JWT_PUBKEY = open(config.SECRETKEY).read()
JWT_ALGORITHM = "RS256"
@ -19,8 +20,9 @@ def decodeJWT(token: str) -> dict:
class JWTBearer(HTTPBearer):
def __init__(self, auto_error: bool = True):
def __init__(self, account_required=True, auto_error: bool = True):
super(JWTBearer, self).__init__(auto_error=auto_error)
self.accountRequired = account_required
async def __call__(self, request: Request):
credentials = await super(JWTBearer, self).__call__(request)
@ -30,6 +32,16 @@ class JWTBearer(HTTPBearer):
if not self.verify_jwt(credentials.credentials):
raise HTTPException(status_code=403, detail="Invalid or expired token.")
request.state.jwt = decodeJWT(credentials.credentials)
if self.accountRequired:
with SessionMaker() as session:
sess = SessionMaker()
query = sess.query(Users).filter(Users.username == str(request.state.jwt["account"]))
first = query.first()
if first:
if first.temporary:
raise HTTPException(status_code=403, detail="Temporary users can't do this")
else:
raise HTTPException(status_code=403, detail="Somehow you have a valid JWT but no account")
return credentials.credentials
else:
raise HTTPException(status_code=403, detail="Invalid authorization code.")

View file

@ -1,9 +1,10 @@
from fastapi import APIRouter
import os
import importlib
router = APIRouter()
for module in os.listdir(os.path.dirname(__file__)):
if module == '__init__.py' or module[-3:] != '.py':
continue
importlib.import_module("."+module[:-3], package="cef_3M.endpoints")
importlib.import_module("." + module[:-3], package="cef_3M.endpoints")

View file

@ -1,4 +1,132 @@
from . import router
from fastapi import UploadFile, Request, Depends
import random
from pydantic import BaseModel
from . import router
from starlette.responses import JSONResponse
from fastapi import Request, HTTPException, Depends
from ..sql import SessionMaker, Users
from ..util import privilegedIps
from ..auth import JWTBearer
import nacl.pwhash
import string
@router.get("/account/exists/{name}")
async def exists(name: str):
with SessionMaker() as session:
check = session.query(Users).filter(Users.username == str(name))
first = check.first()
if first is None:
return JSONResponse({
"exists": False,
"temporary": False
})
else:
return JSONResponse({
"exists": True,
"temporary": bool(first.temporary)
})
class PasswordChange(BaseModel):
currentPassword: str
newPassword: str
newPasswordAgain: str
@router.get("/account/invite", dependencies=[Depends(JWTBearer())])
async def getInvite(request: Request):
username = request.state.jwt["account"]
with SessionMaker() as session:
user = session.query(Users).filter(Users.username == username).first()
return JSONResponse({
"code": user.invite_code
})
@router.post("/account/invite/regenerate", dependencies=[Depends(JWTBearer())])
async def regenInvite(request: Request):
username = request.state.jwt["account"]
code = ""
for _ in range(8):
code += random.choice(string.ascii_uppercase)
with SessionMaker() as session:
user = session.query(Users).filter(Users.username == username).first()
user.invite_code = code
session.commit()
return JSONResponse({
"code": code
})
@router.post("/account/password", dependencies=[Depends(JWTBearer(False))])
async def changePassword(request: Request, passwordData: PasswordChange):
if passwordData.newPassword != passwordData.newPasswordAgain:
raise HTTPException(status_code=400, detail="Passwords don't match")
if len(passwordData.newPassword) <= 5:
raise HTTPException(status_code=400, detail="Come on, at least longer than 5 characters")
whoami = request.state.jwt
username = whoami["account"].lower()
with SessionMaker() as session:
user = session.query(Users).filter(Users.username == username).first()
bPassOld = passwordData.currentPassword.encode("utf8")
try:
nacl.pwhash.scrypt.verify(user.password.encode("utf8"), bPassOld)
except:
raise HTTPException(status_code=403, detail="Invalid original password")
bPass = passwordData.newPassword.encode("utf8")
user.password = nacl.pwhash.scrypt.str(bPass)
user.temporary = False
session.commit()
return JSONResponse({
"success": True
})
@router.post("/account/verify", include_in_schema=False)
async def verify(request: Request):
if request.client.host not in privilegedIps:
return False
body = await request.json()
bPass = body.get("passphrase", "").encode("utf8")
with SessionMaker() as session:
check = session.query(Users).filter(Users.username == str(body["accountName"]))
first = check.first()
if first:
try:
nacl.pwhash.scrypt.verify(first.password.encode("utf8"), bPass)
return JSONResponse({
"success": True,
})
except:
return JSONResponse({
"success": False,
"error": "Incorrect password"
})
else:
# create account
split = bPass.split(b"|")
if len(split) != 2:
return JSONResponse({
"success": False,
"error": "No invite code"
})
code, password = split
firstUser = False
if session.query(Users).count() == 0:
firstUser = True
inviteFrom = session.query(Users).filter(Users.invite_code == code.decode("utf8")).first()
if not inviteFrom and not firstUser:
return JSONResponse({
"success": False,
"error": "Bad invite code"
})
print("invite code", code, "password", password)
account = Users(username=body["accountName"], password=nacl.pwhash.scrypt.str(password), temporary=True)
session.add(account)
session.commit()
return JSONResponse({
"success": True,
})

View file

@ -3,7 +3,6 @@ from starlette.responses import JSONResponse
from . import router
from .. import JWTBearer
from ..sql import SessionMaker, AlertEndpoints
from ..util import redis
from pydantic import BaseModel, create_model, HttpUrl
@ -14,38 +13,42 @@ class SubscriptionData(BaseModel):
@router.post("/alert/register", dependencies=[Depends(JWTBearer())])
async def register(request: Request, subscription: SubscriptionData):
session = SessionMaker()
check = session.query(AlertEndpoints).filter(AlertEndpoints.url == str(subscription.endpoint))
if check.first() is not None:
with SessionMaker() as session:
check = session.query(AlertEndpoints).filter(AlertEndpoints.url == str(subscription.endpoint))
if check.first() is not None:
return JSONResponse({
"success": True
})
info = AlertEndpoints(username=request.state.jwt["account"], url=str(subscription.endpoint),
auth=subscription.keys.auth, p256dh=subscription.keys.p256dh)
session.merge(info)
session.commit()
return JSONResponse({
"success": True
})
info = AlertEndpoints(username=request.state.jwt["account"], url=str(subscription.endpoint), auth=subscription.keys.auth, p256dh=subscription.keys.p256dh)
session.merge(info)
session.commit()
return JSONResponse({
"success": True
})
@router.post("/alert/unregister", dependencies=[Depends(JWTBearer())])
async def unregister(request: Request):
session = SessionMaker()
body = await request.json()
session.query(AlertEndpoints).filter(AlertEndpoints.url == body.get("url", ""), AlertEndpoints.username == request.state.jwt["account"]).delete()
session.commit()
with SessionMaker() as session:
body = await request.json()
session.query(AlertEndpoints).filter(AlertEndpoints.url == body.get("url", ""),
AlertEndpoints.username == request.state.jwt["account"]).delete()
session.commit()
return JSONResponse({
"success": True
})
return JSONResponse({
"success": True
})
@router.post("/alert/clear", dependencies=[Depends(JWTBearer())])
async def clear(request: Request):
session = SessionMaker()
session.query(AlertEndpoints).filter(AlertEndpoints.username == request.state.jwt["account"]).delete()
session.commit()
with SessionMaker() as session:
session.query(AlertEndpoints).filter(AlertEndpoints.username == request.state.jwt["account"]).delete()
session.commit()
return JSONResponse({
"success": True
})
return JSONResponse({
"success": True
})

View file

@ -4,7 +4,7 @@ from . import router
from fastapi import Request, Depends
from ..auth import decodeJWT, JWTBearer
from ..util import redis, ergo
from ..util import redis, ergo, privilegedIps
def pathParts(path):
@ -35,7 +35,7 @@ async def mediamtxChannelStreams(request: Request, channel: str):
@router.post("/mediamtx/auth", include_in_schema=False)
async def mediamtxAuth(request: Request):
if request.client.host != "127.0.0.1":
if request.client.host not in privilegedIps:
return False
body = await request.json()
jwt = decodeJWT(body["query"][4:])
@ -69,7 +69,7 @@ async def mediamtxAuth(request: Request):
@router.post("/mediamtx/add", include_in_schema=False)
async def mediamtxAdd(request: Request):
if request.client.host != "127.0.0.1":
if request.client.host not in privilegedIps:
return False
body = await request.json()
path = body["env"]["MTX_PATH"].split("/")
@ -79,10 +79,9 @@ async def mediamtxAdd(request: Request):
await ergo.broadcastTo(parts[0], "STREAMSTART", parts[0], parts[1], parts[2])
@router.post("/mediamtx/del", include_in_schema=False)
async def mediamtxDelete(request: Request):
if request.client.host != "127.0.0.1":
if request.client.host not in privilegedIps:
return False
body = await request.json()
path = body["env"]["MTX_PATH"].split("/")
@ -90,6 +89,3 @@ async def mediamtxDelete(request: Request):
await redis.delete("stream " + " ".join(parts))
if len(parts) == 3:
await ergo.broadcastTo(parts[0], "STREAMEND", parts[0], parts[1], parts[2])

View file

@ -13,6 +13,7 @@ from pywuffs.aux import (
ImageDecoder,
ImageDecoderConfig,
)
pfpConfig = ImageDecoderConfig()
pfpConfig.max_incl_dimension = 400
pfpConfig.enabled_decoders = [
@ -27,6 +28,7 @@ iconConfig.enabled_decoders = [
ImageDecoderType.PNG,
]
@router.post("/pfp/upload", dependencies=[Depends(JWTBearer())])
async def pfpUpload(file: UploadFile, request: Request):
if file.size > config.MAX_PFP_SIZE:
@ -48,6 +50,7 @@ async def pfpUpload(file: UploadFile, request: Request):
await ergo.broadcastAs(username, "CACHEBUST")
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}?{time.time():.0f}"}
@router.post("/pfp/uploadIcon", dependencies=[Depends(JWTBearer())])
async def IconUpload(file: UploadFile, request: Request):
if file.size > config.MAX_PFP_SIZE:
@ -63,6 +66,6 @@ async def IconUpload(file: UploadFile, request: Request):
file.file.seek(0)
mime = mimetypes.guess_type(file.filename)
minioClient.put_object("pfp", username+".icon", file.file, file.size, content_type=mime[0])
minioClient.put_object("pfp", username + ".icon", file.file, file.size, content_type=mime[0])
await ergo.broadcastAs(username, "CACHEBUST")
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}.icon?{time.time():.0f}"}
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}.icon?{time.time():.0f}"}

View file

@ -0,0 +1,29 @@
from . import router
from fastapi import UploadFile, Request, Depends
from ..sql import SessionMaker, Uploads
from datetime import datetime, timedelta
from .. import util
from ..auth import JWTBearer
import config
import mimetypes
@router.post("/upload", dependencies=[Depends(JWTBearer(False))])
async def upload(file: UploadFile, request: Request):
if file.size > config.MAX_FILE_SIZE:
return {"error": "file too big"}
spl = file.filename.rsplit(".", 1)
safeFilename = util.safeName.sub("_", spl[0])
if len(spl) == 2:
safeFilename += "." + util.safeName.sub("_", spl[1])
sha = await util.SHA256(file)
with SessionMaker() as session:
if existing := session.query(Uploads).where(Uploads.hash == sha).first():
existing.expiry = datetime.now() + timedelta(days=7)
else:
mime = mimetypes.guess_type(safeFilename)
util.minioClient.put_object("uploads", sha, file.file, file.size, content_type=mime[0])
up = Uploads(hash=sha)
session.add(up)
session.commit()
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/uploads/{sha}/{safeFilename}"}

View file

@ -8,10 +8,7 @@ import configparser
alembic = configparser.ConfigParser()
alembic.read("alembic.ini")
try:
dburl = alembic.get("alembic", "sqlalchemy.url")
except:
dburl = config.DBURL
dburl = config.DBURL
engine = create_engine(
dburl,
@ -27,7 +24,8 @@ ergoEngine = create_engine(
SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine, )
def ergoQueryFetchOne(q: str, **kwargs):
with ergoEngine.connect() as connection:
connection.execute(text("use ergo"))
return connection.execute(text(q), kwargs).fetchone()
return connection.execute(text(q), kwargs).fetchone()

View file

@ -38,3 +38,4 @@ class Users(Base):
password: Mapped[str] = mapped_column(String(128))
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1'))
invite_code: Mapped[Optional[str]] = mapped_column(String(32))

View file

@ -2,6 +2,8 @@ import asyncio
import hashlib
import json
import re
import socket
import traceback
import MySQLdb
@ -14,9 +16,9 @@ from .sql import SessionMaker, AlertEndpoints, ergoQueryFetchOne
from fastapi import UploadFile
safeName = re.compile(r"[^\w\d\.-]")
# If this gets too out of hand, put an async breakpoint to allow other things to be handled while the hash occurs
async def SHA256(f: UploadFile) -> str:
sha = hashlib.sha256()
@ -25,6 +27,7 @@ async def SHA256(f: UploadFile) -> str:
await f.seek(0)
return sha.hexdigest()
minioClient = Minio(
config.MINIO_INTERNAL_ADDR,
secure=False, # you will probably not have SSL
@ -32,7 +35,8 @@ minioClient = Minio(
secret_key=config.MINIO_SECRET_KEY,
)
redis = Redis(host='localhost', port=6379, db=0, protocol=3)
redis = Redis(host=config.REDIS_ADDR, port=6379, db=0, protocol=3)
class ErgoClient:
def __init__(self):
@ -40,34 +44,19 @@ class ErgoClient:
self.writer = None
asyncio.get_running_loop().create_task(self.init())
@staticmethod
def retry(f):
async def wrapper(self, *args, **kwargs):
i = 30
while i:
try:
return await f(self, *args, **kwargs)
except RuntimeError:
self.init()
i -= 1
print("Couldn't connect")
return wrapper
@retry
async def init(self):
self.reader, self.writer = await asyncio.open_connection(config.ERGO_ADDR, config.ERGO_PORT)
await asyncio.get_running_loop().create_task(self.readEvents())
await asyncio.create_task(self.readEvents())
@retry
async def readEvents(self):
while 1:
rawLine = await self.reader.readline()
if not rawLine: break
if not rawLine:
break
line = rawLine.decode("utf8").strip().split()
if line[0] == "MENTION":
await self.handleMention(line[1], line[2], line[3])
async def handleMention(self, username: str, channel: str, msgid: str):
session = SessionMaker()
for target in session.query(AlertEndpoints).filter(AlertEndpoints.username == username):
@ -89,19 +78,28 @@ class ErgoClient:
await pusher.send_async(encoded)
session.close()
@retry
async def write(self, msg):
self.writer.write(msg+b"\n")
if self.writer is None:
for _ in range(30):
await asyncio.sleep(1)
if self.writer:
break
self.writer.write(msg + b"\n")
await self.writer.drain()
async def broadcastAs(self, user, *message):
await self.write(f"BROADCASTAS {user} {' '.join(message)}".encode("utf8"))
async def fullyRemoveUser(self, user):
await self.write(f"FULLYREMOVE {user}".encode("utf8"))
async def broadcastTo(self, user, *message):
await self.write(f"BROADCASTTO {user} {' '.join(message)}".encode("utf8"))
ergo = ErgoClient()
privilegedIps = set()
for host in config.PRIVILEGED_HOSTS:
for addr in [x[-1][0] for x in socket.getaddrinfo("localhost", 0)]:
privilegedIps.add(addr)