add authentication
add built-in cleanup general refactor
This commit is contained in:
parent
20ee543ab4
commit
ba2e896813
15 changed files with 313 additions and 106 deletions
26
alembic/versions/3c43e544e939_create_invites.py
Normal file
26
alembic/versions/3c43e544e939_create_invites.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""create invites
|
||||||
|
|
||||||
|
Revision ID: 3c43e544e939
|
||||||
|
Revises: f73144466860
|
||||||
|
Create Date: 2024-07-29 01:20:12.377093
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '3c43e544e939'
|
||||||
|
down_revision: Union[str, None] = 'f73144466860'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column("users", sa.Column("invite_code", sa.String(32), default=None))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("users", "invite_code")
|
||||||
|
|
@ -1,15 +1,46 @@
|
||||||
from fastapi import FastAPI, UploadFile, Request, Depends
|
import datetime
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import mimetypes
|
|
||||||
|
|
||||||
from .auth import JWTBearer
|
from .auth import JWTBearer
|
||||||
from .sql import SessionMaker, Uploads
|
from .sql import SessionMaker, Uploads, Users
|
||||||
from . import util
|
|
||||||
import config
|
import config
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from . import endpoints
|
from . import endpoints
|
||||||
|
from .util import minioClient, ergo
|
||||||
|
from fastapi_utils.tasks import repeat_every
|
||||||
|
from sqlalchemy import and_
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
@repeat_every(seconds=60 * 60)
|
||||||
|
async def cleanup():
|
||||||
|
sess = SessionMaker()
|
||||||
|
toCleanup = sess.query(Uploads).filter(Uploads.expiry < datetime.datetime.utcnow()).all()
|
||||||
|
for upload in toCleanup:
|
||||||
|
minioClient.remove_object("uploads", upload.hash)
|
||||||
|
sess.delete(upload)
|
||||||
|
print(f"Deleted {len(toCleanup)} old files")
|
||||||
|
|
||||||
|
tempUsers = sess.query(Users).filter(and_(
|
||||||
|
Users.temporary == True,
|
||||||
|
Users.created_at < (datetime.datetime.utcnow() - datetime.timedelta(days=1))
|
||||||
|
)).all()
|
||||||
|
for user in tempUsers:
|
||||||
|
await ergo.fullyRemoveUser(user.username)
|
||||||
|
sess.delete(user)
|
||||||
|
print(f"Removed {len(tempUsers)} temp users")
|
||||||
|
sess.commit()
|
||||||
|
sess.close()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def onStart(app: FastAPI):
|
||||||
|
await cleanup()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=onStart)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=config.ALLOWED_DOMAINS,
|
allow_origins=config.ALLOWED_DOMAINS,
|
||||||
|
|
@ -20,26 +51,4 @@ app.add_middleware(
|
||||||
|
|
||||||
app.include_router(endpoints.router)
|
app.include_router(endpoints.router)
|
||||||
|
|
||||||
@app.post("/upload", dependencies=[Depends(JWTBearer())])
|
|
||||||
async def upload(file: UploadFile, request: Request):
|
|
||||||
if file.size > config.MAX_FILE_SIZE:
|
|
||||||
return {"error": "file too big"}
|
|
||||||
spl = file.filename.rsplit(".", 1)
|
|
||||||
safeFilename = util.safeName.sub("_", spl[0])
|
|
||||||
if len(spl) == 2:
|
|
||||||
safeFilename += "." + util.safeName.sub("_", spl[1])
|
|
||||||
sha = await util.SHA256(file)
|
|
||||||
session = SessionMaker()
|
|
||||||
if existing := session.query(Uploads).where(Uploads.hash == sha).first():
|
|
||||||
existing.expiry = datetime.now() + timedelta(days=7)
|
|
||||||
else:
|
|
||||||
mime = mimetypes.guess_type(safeFilename)
|
|
||||||
util.minioClient.put_object("uploads", sha, file.file, file.size, content_type=mime[0])
|
|
||||||
up = Uploads(hash=sha)
|
|
||||||
session.add(up)
|
|
||||||
session.commit()
|
|
||||||
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/uploads/{sha}/{safeFilename}"}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["sql", "auth", "util"]
|
__all__ = ["sql", "auth", "util"]
|
||||||
|
|
@ -5,6 +5,7 @@ from fastapi.security import HTTPBearer
|
||||||
import config
|
import config
|
||||||
from fastapi import Request, HTTPException
|
from fastapi import Request, HTTPException
|
||||||
|
|
||||||
|
from .sql import SessionMaker, Users
|
||||||
|
|
||||||
JWT_PUBKEY = open(config.SECRETKEY).read()
|
JWT_PUBKEY = open(config.SECRETKEY).read()
|
||||||
JWT_ALGORITHM = "RS256"
|
JWT_ALGORITHM = "RS256"
|
||||||
|
|
@ -19,8 +20,9 @@ def decodeJWT(token: str) -> dict:
|
||||||
|
|
||||||
|
|
||||||
class JWTBearer(HTTPBearer):
|
class JWTBearer(HTTPBearer):
|
||||||
def __init__(self, auto_error: bool = True):
|
def __init__(self, account_required=True, auto_error: bool = True):
|
||||||
super(JWTBearer, self).__init__(auto_error=auto_error)
|
super(JWTBearer, self).__init__(auto_error=auto_error)
|
||||||
|
self.accountRequired = account_required
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request):
|
||||||
credentials = await super(JWTBearer, self).__call__(request)
|
credentials = await super(JWTBearer, self).__call__(request)
|
||||||
|
|
@ -30,6 +32,16 @@ class JWTBearer(HTTPBearer):
|
||||||
if not self.verify_jwt(credentials.credentials):
|
if not self.verify_jwt(credentials.credentials):
|
||||||
raise HTTPException(status_code=403, detail="Invalid or expired token.")
|
raise HTTPException(status_code=403, detail="Invalid or expired token.")
|
||||||
request.state.jwt = decodeJWT(credentials.credentials)
|
request.state.jwt = decodeJWT(credentials.credentials)
|
||||||
|
if self.accountRequired:
|
||||||
|
with SessionMaker() as session:
|
||||||
|
sess = SessionMaker()
|
||||||
|
query = sess.query(Users).filter(Users.username == str(request.state.jwt["account"]))
|
||||||
|
first = query.first()
|
||||||
|
if first:
|
||||||
|
if first.temporary:
|
||||||
|
raise HTTPException(status_code=403, detail="Temporary users can't do this")
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=403, detail="Somehow you have a valid JWT but no account")
|
||||||
return credentials.credentials
|
return credentials.credentials
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=403, detail="Invalid authorization code.")
|
raise HTTPException(status_code=403, detail="Invalid authorization code.")
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
import os
|
import os
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
for module in os.listdir(os.path.dirname(__file__)):
|
for module in os.listdir(os.path.dirname(__file__)):
|
||||||
if module == '__init__.py' or module[-3:] != '.py':
|
if module == '__init__.py' or module[-3:] != '.py':
|
||||||
continue
|
continue
|
||||||
importlib.import_module("."+module[:-3], package="cef_3M.endpoints")
|
importlib.import_module("." + module[:-3], package="cef_3M.endpoints")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,132 @@
|
||||||
from . import router
|
import random
|
||||||
from fastapi import UploadFile, Request, Depends
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from . import router
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from fastapi import Request, HTTPException, Depends
|
||||||
|
from ..sql import SessionMaker, Users
|
||||||
|
from ..util import privilegedIps
|
||||||
from ..auth import JWTBearer
|
from ..auth import JWTBearer
|
||||||
|
import nacl.pwhash
|
||||||
|
import string
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/account/exists/{name}")
|
||||||
|
async def exists(name: str):
|
||||||
|
with SessionMaker() as session:
|
||||||
|
check = session.query(Users).filter(Users.username == str(name))
|
||||||
|
first = check.first()
|
||||||
|
if first is None:
|
||||||
|
return JSONResponse({
|
||||||
|
"exists": False,
|
||||||
|
"temporary": False
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return JSONResponse({
|
||||||
|
"exists": True,
|
||||||
|
"temporary": bool(first.temporary)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordChange(BaseModel):
|
||||||
|
currentPassword: str
|
||||||
|
newPassword: str
|
||||||
|
newPasswordAgain: str
|
||||||
|
|
||||||
|
@router.get("/account/invite", dependencies=[Depends(JWTBearer())])
|
||||||
|
async def getInvite(request: Request):
|
||||||
|
username = request.state.jwt["account"]
|
||||||
|
with SessionMaker() as session:
|
||||||
|
user = session.query(Users).filter(Users.username == username).first()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"code": user.invite_code
|
||||||
|
})
|
||||||
|
|
||||||
|
@router.post("/account/invite/regenerate", dependencies=[Depends(JWTBearer())])
|
||||||
|
async def regenInvite(request: Request):
|
||||||
|
username = request.state.jwt["account"]
|
||||||
|
code = ""
|
||||||
|
for _ in range(8):
|
||||||
|
code += random.choice(string.ascii_uppercase)
|
||||||
|
with SessionMaker() as session:
|
||||||
|
user = session.query(Users).filter(Users.username == username).first()
|
||||||
|
user.invite_code = code
|
||||||
|
session.commit()
|
||||||
|
return JSONResponse({
|
||||||
|
"code": code
|
||||||
|
})
|
||||||
|
|
||||||
|
@router.post("/account/password", dependencies=[Depends(JWTBearer(False))])
|
||||||
|
async def changePassword(request: Request, passwordData: PasswordChange):
|
||||||
|
if passwordData.newPassword != passwordData.newPasswordAgain:
|
||||||
|
raise HTTPException(status_code=400, detail="Passwords don't match")
|
||||||
|
if len(passwordData.newPassword) <= 5:
|
||||||
|
raise HTTPException(status_code=400, detail="Come on, at least longer than 5 characters")
|
||||||
|
whoami = request.state.jwt
|
||||||
|
username = whoami["account"].lower()
|
||||||
|
|
||||||
|
with SessionMaker() as session:
|
||||||
|
user = session.query(Users).filter(Users.username == username).first()
|
||||||
|
bPassOld = passwordData.currentPassword.encode("utf8")
|
||||||
|
|
||||||
|
try:
|
||||||
|
nacl.pwhash.scrypt.verify(user.password.encode("utf8"), bPassOld)
|
||||||
|
except:
|
||||||
|
raise HTTPException(status_code=403, detail="Invalid original password")
|
||||||
|
|
||||||
|
bPass = passwordData.newPassword.encode("utf8")
|
||||||
|
user.password = nacl.pwhash.scrypt.str(bPass)
|
||||||
|
user.temporary = False
|
||||||
|
session.commit()
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/account/verify", include_in_schema=False)
|
||||||
|
async def verify(request: Request):
|
||||||
|
if request.client.host not in privilegedIps:
|
||||||
|
return False
|
||||||
|
body = await request.json()
|
||||||
|
bPass = body.get("passphrase", "").encode("utf8")
|
||||||
|
with SessionMaker() as session:
|
||||||
|
check = session.query(Users).filter(Users.username == str(body["accountName"]))
|
||||||
|
first = check.first()
|
||||||
|
if first:
|
||||||
|
try:
|
||||||
|
nacl.pwhash.scrypt.verify(first.password.encode("utf8"), bPass)
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
})
|
||||||
|
except:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": False,
|
||||||
|
"error": "Incorrect password"
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# create account
|
||||||
|
split = bPass.split(b"|")
|
||||||
|
if len(split) != 2:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": False,
|
||||||
|
"error": "No invite code"
|
||||||
|
})
|
||||||
|
code, password = split
|
||||||
|
firstUser = False
|
||||||
|
if session.query(Users).count() == 0:
|
||||||
|
firstUser = True
|
||||||
|
inviteFrom = session.query(Users).filter(Users.invite_code == code.decode("utf8")).first()
|
||||||
|
if not inviteFrom and not firstUser:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": False,
|
||||||
|
"error": "Bad invite code"
|
||||||
|
})
|
||||||
|
print("invite code", code, "password", password)
|
||||||
|
account = Users(username=body["accountName"], password=nacl.pwhash.scrypt.str(password), temporary=True)
|
||||||
|
session.add(account)
|
||||||
|
session.commit()
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True,
|
||||||
|
})
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from starlette.responses import JSONResponse
|
||||||
from . import router
|
from . import router
|
||||||
from .. import JWTBearer
|
from .. import JWTBearer
|
||||||
from ..sql import SessionMaker, AlertEndpoints
|
from ..sql import SessionMaker, AlertEndpoints
|
||||||
from ..util import redis
|
|
||||||
from pydantic import BaseModel, create_model, HttpUrl
|
from pydantic import BaseModel, create_model, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,38 +13,42 @@ class SubscriptionData(BaseModel):
|
||||||
|
|
||||||
@router.post("/alert/register", dependencies=[Depends(JWTBearer())])
|
@router.post("/alert/register", dependencies=[Depends(JWTBearer())])
|
||||||
async def register(request: Request, subscription: SubscriptionData):
|
async def register(request: Request, subscription: SubscriptionData):
|
||||||
session = SessionMaker()
|
with SessionMaker() as session:
|
||||||
check = session.query(AlertEndpoints).filter(AlertEndpoints.url == str(subscription.endpoint))
|
check = session.query(AlertEndpoints).filter(AlertEndpoints.url == str(subscription.endpoint))
|
||||||
if check.first() is not None:
|
if check.first() is not None:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
info = AlertEndpoints(username=request.state.jwt["account"], url=str(subscription.endpoint),
|
||||||
|
auth=subscription.keys.auth, p256dh=subscription.keys.p256dh)
|
||||||
|
|
||||||
|
session.merge(info)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"success": True
|
"success": True
|
||||||
})
|
})
|
||||||
info = AlertEndpoints(username=request.state.jwt["account"], url=str(subscription.endpoint), auth=subscription.keys.auth, p256dh=subscription.keys.p256dh)
|
|
||||||
|
|
||||||
session.merge(info)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
return JSONResponse({
|
|
||||||
"success": True
|
|
||||||
})
|
|
||||||
|
|
||||||
@router.post("/alert/unregister", dependencies=[Depends(JWTBearer())])
|
@router.post("/alert/unregister", dependencies=[Depends(JWTBearer())])
|
||||||
async def unregister(request: Request):
|
async def unregister(request: Request):
|
||||||
session = SessionMaker()
|
with SessionMaker() as session:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
session.query(AlertEndpoints).filter(AlertEndpoints.url == body.get("url", ""), AlertEndpoints.username == request.state.jwt["account"]).delete()
|
session.query(AlertEndpoints).filter(AlertEndpoints.url == body.get("url", ""),
|
||||||
session.commit()
|
AlertEndpoints.username == request.state.jwt["account"]).delete()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
|
||||||
return JSONResponse({
|
|
||||||
"success": True
|
|
||||||
})
|
|
||||||
|
|
||||||
@router.post("/alert/clear", dependencies=[Depends(JWTBearer())])
|
@router.post("/alert/clear", dependencies=[Depends(JWTBearer())])
|
||||||
async def clear(request: Request):
|
async def clear(request: Request):
|
||||||
session = SessionMaker()
|
with SessionMaker() as session:
|
||||||
session.query(AlertEndpoints).filter(AlertEndpoints.username == request.state.jwt["account"]).delete()
|
session.query(AlertEndpoints).filter(AlertEndpoints.username == request.state.jwt["account"]).delete()
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"success": True
|
"success": True
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from . import router
|
||||||
from fastapi import Request, Depends
|
from fastapi import Request, Depends
|
||||||
|
|
||||||
from ..auth import decodeJWT, JWTBearer
|
from ..auth import decodeJWT, JWTBearer
|
||||||
from ..util import redis, ergo
|
from ..util import redis, ergo, privilegedIps
|
||||||
|
|
||||||
|
|
||||||
def pathParts(path):
|
def pathParts(path):
|
||||||
|
|
@ -35,7 +35,7 @@ async def mediamtxChannelStreams(request: Request, channel: str):
|
||||||
|
|
||||||
@router.post("/mediamtx/auth", include_in_schema=False)
|
@router.post("/mediamtx/auth", include_in_schema=False)
|
||||||
async def mediamtxAuth(request: Request):
|
async def mediamtxAuth(request: Request):
|
||||||
if request.client.host != "127.0.0.1":
|
if request.client.host not in privilegedIps:
|
||||||
return False
|
return False
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
jwt = decodeJWT(body["query"][4:])
|
jwt = decodeJWT(body["query"][4:])
|
||||||
|
|
@ -69,7 +69,7 @@ async def mediamtxAuth(request: Request):
|
||||||
|
|
||||||
@router.post("/mediamtx/add", include_in_schema=False)
|
@router.post("/mediamtx/add", include_in_schema=False)
|
||||||
async def mediamtxAdd(request: Request):
|
async def mediamtxAdd(request: Request):
|
||||||
if request.client.host != "127.0.0.1":
|
if request.client.host not in privilegedIps:
|
||||||
return False
|
return False
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
path = body["env"]["MTX_PATH"].split("/")
|
path = body["env"]["MTX_PATH"].split("/")
|
||||||
|
|
@ -79,10 +79,9 @@ async def mediamtxAdd(request: Request):
|
||||||
await ergo.broadcastTo(parts[0], "STREAMSTART", parts[0], parts[1], parts[2])
|
await ergo.broadcastTo(parts[0], "STREAMSTART", parts[0], parts[1], parts[2])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/mediamtx/del", include_in_schema=False)
|
@router.post("/mediamtx/del", include_in_schema=False)
|
||||||
async def mediamtxDelete(request: Request):
|
async def mediamtxDelete(request: Request):
|
||||||
if request.client.host != "127.0.0.1":
|
if request.client.host not in privilegedIps:
|
||||||
return False
|
return False
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
path = body["env"]["MTX_PATH"].split("/")
|
path = body["env"]["MTX_PATH"].split("/")
|
||||||
|
|
@ -90,6 +89,3 @@ async def mediamtxDelete(request: Request):
|
||||||
await redis.delete("stream " + " ".join(parts))
|
await redis.delete("stream " + " ".join(parts))
|
||||||
if len(parts) == 3:
|
if len(parts) == 3:
|
||||||
await ergo.broadcastTo(parts[0], "STREAMEND", parts[0], parts[1], parts[2])
|
await ergo.broadcastTo(parts[0], "STREAMEND", parts[0], parts[1], parts[2])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from pywuffs.aux import (
|
||||||
ImageDecoder,
|
ImageDecoder,
|
||||||
ImageDecoderConfig,
|
ImageDecoderConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
pfpConfig = ImageDecoderConfig()
|
pfpConfig = ImageDecoderConfig()
|
||||||
pfpConfig.max_incl_dimension = 400
|
pfpConfig.max_incl_dimension = 400
|
||||||
pfpConfig.enabled_decoders = [
|
pfpConfig.enabled_decoders = [
|
||||||
|
|
@ -27,6 +28,7 @@ iconConfig.enabled_decoders = [
|
||||||
ImageDecoderType.PNG,
|
ImageDecoderType.PNG,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/pfp/upload", dependencies=[Depends(JWTBearer())])
|
@router.post("/pfp/upload", dependencies=[Depends(JWTBearer())])
|
||||||
async def pfpUpload(file: UploadFile, request: Request):
|
async def pfpUpload(file: UploadFile, request: Request):
|
||||||
if file.size > config.MAX_PFP_SIZE:
|
if file.size > config.MAX_PFP_SIZE:
|
||||||
|
|
@ -48,6 +50,7 @@ async def pfpUpload(file: UploadFile, request: Request):
|
||||||
await ergo.broadcastAs(username, "CACHEBUST")
|
await ergo.broadcastAs(username, "CACHEBUST")
|
||||||
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}?{time.time():.0f}"}
|
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}?{time.time():.0f}"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/pfp/uploadIcon", dependencies=[Depends(JWTBearer())])
|
@router.post("/pfp/uploadIcon", dependencies=[Depends(JWTBearer())])
|
||||||
async def IconUpload(file: UploadFile, request: Request):
|
async def IconUpload(file: UploadFile, request: Request):
|
||||||
if file.size > config.MAX_PFP_SIZE:
|
if file.size > config.MAX_PFP_SIZE:
|
||||||
|
|
@ -63,6 +66,6 @@ async def IconUpload(file: UploadFile, request: Request):
|
||||||
file.file.seek(0)
|
file.file.seek(0)
|
||||||
|
|
||||||
mime = mimetypes.guess_type(file.filename)
|
mime = mimetypes.guess_type(file.filename)
|
||||||
minioClient.put_object("pfp", username+".icon", file.file, file.size, content_type=mime[0])
|
minioClient.put_object("pfp", username + ".icon", file.file, file.size, content_type=mime[0])
|
||||||
await ergo.broadcastAs(username, "CACHEBUST")
|
await ergo.broadcastAs(username, "CACHEBUST")
|
||||||
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}.icon?{time.time():.0f}"}
|
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/pfp/{username}.icon?{time.time():.0f}"}
|
||||||
29
cef_3M/endpoints/uploads.py
Normal file
29
cef_3M/endpoints/uploads.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
from . import router
|
||||||
|
from fastapi import UploadFile, Request, Depends
|
||||||
|
from ..sql import SessionMaker, Uploads
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from .. import util
|
||||||
|
from ..auth import JWTBearer
|
||||||
|
import config
|
||||||
|
import mimetypes
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", dependencies=[Depends(JWTBearer(False))])
|
||||||
|
async def upload(file: UploadFile, request: Request):
|
||||||
|
if file.size > config.MAX_FILE_SIZE:
|
||||||
|
return {"error": "file too big"}
|
||||||
|
spl = file.filename.rsplit(".", 1)
|
||||||
|
safeFilename = util.safeName.sub("_", spl[0])
|
||||||
|
if len(spl) == 2:
|
||||||
|
safeFilename += "." + util.safeName.sub("_", spl[1])
|
||||||
|
sha = await util.SHA256(file)
|
||||||
|
with SessionMaker() as session:
|
||||||
|
if existing := session.query(Uploads).where(Uploads.hash == sha).first():
|
||||||
|
existing.expiry = datetime.now() + timedelta(days=7)
|
||||||
|
else:
|
||||||
|
mime = mimetypes.guess_type(safeFilename)
|
||||||
|
util.minioClient.put_object("uploads", sha, file.file, file.size, content_type=mime[0])
|
||||||
|
up = Uploads(hash=sha)
|
||||||
|
session.add(up)
|
||||||
|
session.commit()
|
||||||
|
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/uploads/{sha}/{safeFilename}"}
|
||||||
|
|
@ -8,10 +8,7 @@ import configparser
|
||||||
alembic = configparser.ConfigParser()
|
alembic = configparser.ConfigParser()
|
||||||
alembic.read("alembic.ini")
|
alembic.read("alembic.ini")
|
||||||
|
|
||||||
try:
|
dburl = config.DBURL
|
||||||
dburl = alembic.get("alembic", "sqlalchemy.url")
|
|
||||||
except:
|
|
||||||
dburl = config.DBURL
|
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
dburl,
|
dburl,
|
||||||
|
|
@ -27,6 +24,7 @@ ergoEngine = create_engine(
|
||||||
|
|
||||||
SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine, )
|
SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine, )
|
||||||
|
|
||||||
|
|
||||||
def ergoQueryFetchOne(q: str, **kwargs):
|
def ergoQueryFetchOne(q: str, **kwargs):
|
||||||
with ergoEngine.connect() as connection:
|
with ergoEngine.connect() as connection:
|
||||||
connection.execute(text("use ergo"))
|
connection.execute(text("use ergo"))
|
||||||
|
|
|
||||||
|
|
@ -38,3 +38,4 @@ class Users(Base):
|
||||||
password: Mapped[str] = mapped_column(String(128))
|
password: Mapped[str] = mapped_column(String(128))
|
||||||
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
|
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
|
||||||
temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1'))
|
temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1'))
|
||||||
|
invite_code: Mapped[Optional[str]] = mapped_column(String(32))
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import socket
|
||||||
|
import traceback
|
||||||
|
|
||||||
import MySQLdb
|
import MySQLdb
|
||||||
|
|
||||||
|
|
@ -14,9 +16,9 @@ from .sql import SessionMaker, AlertEndpoints, ergoQueryFetchOne
|
||||||
|
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
|
||||||
safeName = re.compile(r"[^\w\d\.-]")
|
safeName = re.compile(r"[^\w\d\.-]")
|
||||||
|
|
||||||
|
|
||||||
# If this gets too out of hand, put an async breakpoint to allow other things to be handled while the hash occurs
|
# If this gets too out of hand, put an async breakpoint to allow other things to be handled while the hash occurs
|
||||||
async def SHA256(f: UploadFile) -> str:
|
async def SHA256(f: UploadFile) -> str:
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
|
|
@ -25,6 +27,7 @@ async def SHA256(f: UploadFile) -> str:
|
||||||
await f.seek(0)
|
await f.seek(0)
|
||||||
return sha.hexdigest()
|
return sha.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
minioClient = Minio(
|
minioClient = Minio(
|
||||||
config.MINIO_INTERNAL_ADDR,
|
config.MINIO_INTERNAL_ADDR,
|
||||||
secure=False, # you will probably not have SSL
|
secure=False, # you will probably not have SSL
|
||||||
|
|
@ -32,7 +35,8 @@ minioClient = Minio(
|
||||||
secret_key=config.MINIO_SECRET_KEY,
|
secret_key=config.MINIO_SECRET_KEY,
|
||||||
)
|
)
|
||||||
|
|
||||||
redis = Redis(host='localhost', port=6379, db=0, protocol=3)
|
redis = Redis(host=config.REDIS_ADDR, port=6379, db=0, protocol=3)
|
||||||
|
|
||||||
|
|
||||||
class ErgoClient:
|
class ErgoClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -40,34 +44,19 @@ class ErgoClient:
|
||||||
self.writer = None
|
self.writer = None
|
||||||
asyncio.get_running_loop().create_task(self.init())
|
asyncio.get_running_loop().create_task(self.init())
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def retry(f):
|
|
||||||
async def wrapper(self, *args, **kwargs):
|
|
||||||
i = 30
|
|
||||||
while i:
|
|
||||||
try:
|
|
||||||
return await f(self, *args, **kwargs)
|
|
||||||
except RuntimeError:
|
|
||||||
self.init()
|
|
||||||
i -= 1
|
|
||||||
print("Couldn't connect")
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
@retry
|
|
||||||
async def init(self):
|
async def init(self):
|
||||||
self.reader, self.writer = await asyncio.open_connection(config.ERGO_ADDR, config.ERGO_PORT)
|
self.reader, self.writer = await asyncio.open_connection(config.ERGO_ADDR, config.ERGO_PORT)
|
||||||
await asyncio.get_running_loop().create_task(self.readEvents())
|
await asyncio.create_task(self.readEvents())
|
||||||
|
|
||||||
@retry
|
|
||||||
async def readEvents(self):
|
async def readEvents(self):
|
||||||
while 1:
|
while 1:
|
||||||
rawLine = await self.reader.readline()
|
rawLine = await self.reader.readline()
|
||||||
if not rawLine: break
|
if not rawLine:
|
||||||
|
break
|
||||||
line = rawLine.decode("utf8").strip().split()
|
line = rawLine.decode("utf8").strip().split()
|
||||||
if line[0] == "MENTION":
|
if line[0] == "MENTION":
|
||||||
await self.handleMention(line[1], line[2], line[3])
|
await self.handleMention(line[1], line[2], line[3])
|
||||||
|
|
||||||
|
|
||||||
async def handleMention(self, username: str, channel: str, msgid: str):
|
async def handleMention(self, username: str, channel: str, msgid: str):
|
||||||
session = SessionMaker()
|
session = SessionMaker()
|
||||||
for target in session.query(AlertEndpoints).filter(AlertEndpoints.username == username):
|
for target in session.query(AlertEndpoints).filter(AlertEndpoints.username == username):
|
||||||
|
|
@ -89,19 +78,28 @@ class ErgoClient:
|
||||||
await pusher.send_async(encoded)
|
await pusher.send_async(encoded)
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
@retry
|
|
||||||
async def write(self, msg):
|
async def write(self, msg):
|
||||||
|
if self.writer is None:
|
||||||
self.writer.write(msg+b"\n")
|
for _ in range(30):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
if self.writer:
|
||||||
|
break
|
||||||
|
self.writer.write(msg + b"\n")
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
async def broadcastAs(self, user, *message):
|
async def broadcastAs(self, user, *message):
|
||||||
await self.write(f"BROADCASTAS {user} {' '.join(message)}".encode("utf8"))
|
await self.write(f"BROADCASTAS {user} {' '.join(message)}".encode("utf8"))
|
||||||
|
|
||||||
|
async def fullyRemoveUser(self, user):
|
||||||
|
await self.write(f"FULLYREMOVE {user}".encode("utf8"))
|
||||||
|
|
||||||
async def broadcastTo(self, user, *message):
|
async def broadcastTo(self, user, *message):
|
||||||
await self.write(f"BROADCASTTO {user} {' '.join(message)}".encode("utf8"))
|
await self.write(f"BROADCASTTO {user} {' '.join(message)}".encode("utf8"))
|
||||||
|
|
||||||
|
|
||||||
ergo = ErgoClient()
|
ergo = ErgoClient()
|
||||||
|
|
||||||
|
privilegedIps = set()
|
||||||
|
for host in config.PRIVILEGED_HOSTS:
|
||||||
|
for addr in [x[-1][0] for x in socket.getaddrinfo("localhost", 0)]:
|
||||||
|
privilegedIps.add(addr)
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,16 @@ MINIO_EXTERNAL_ADDR = os.getenv("THREEM_MINIO_EXTERNAL_ADDR") or "data.example.x
|
||||||
MINIO_ACCESS_KEY = os.getenv("THREEM_MINIO_ACCESS_KEY") or "access-key-goes-here"
|
MINIO_ACCESS_KEY = os.getenv("THREEM_MINIO_ACCESS_KEY") or "access-key-goes-here"
|
||||||
MINIO_SECRET_KEY = os.getenv("THREEM_MINIO_SECRET_KEY") or "secret-key-goes-here"
|
MINIO_SECRET_KEY = os.getenv("THREEM_MINIO_SECRET_KEY") or "secret-key-goes-here"
|
||||||
DBURL = os.getenv("THREEM_DBURL") or "mysql+mysqldb://ergo:password@localhost/ergo_ext"
|
DBURL = os.getenv("THREEM_DBURL") or "mysql+mysqldb://ergo:password@localhost/ergo_ext"
|
||||||
|
REDIS_ADDR = os.getenv("THREEM_REDIS_ADDR") or "localhost"
|
||||||
|
|
||||||
MAX_FILE_SIZE = 1024*1024*20
|
MAX_FILE_SIZE = 1024*1024*20
|
||||||
MAX_PFP_SIZE = 1024*1024*1.5
|
MAX_PFP_SIZE = 1024*1024*1.5
|
||||||
# It's a 24x24 image, you can fit that in 32k
|
# It's a 24x24 image, you can fit that in 32k
|
||||||
MAX_ICON_SIZE = 1024*32
|
MAX_ICON_SIZE = 1024*32
|
||||||
|
|
||||||
|
# Some endpoints are restricted by accessing IP (bad, I know, but it's what I got for the moment)
|
||||||
|
# This causes a lookup at initialization to limit those endpoints. Mainly for Docker
|
||||||
|
PRIVILEGED_HOSTS = ["localhost"]
|
||||||
|
|
||||||
# Need to figure out how to make this cooperate more
|
# Need to figure out how to make this cooperate more
|
||||||
ALLOWED_DOMAINS = ["*"]
|
ALLOWED_DOMAINS = ["*"]
|
||||||
|
|
@ -13,6 +13,7 @@ dnspython==2.6.1
|
||||||
email_validator==2.1.1
|
email_validator==2.1.1
|
||||||
fastapi==0.111.1
|
fastapi==0.111.1
|
||||||
fastapi-cli==0.0.3
|
fastapi-cli==0.0.3
|
||||||
|
fastapi-utils==0.7.0
|
||||||
frozenlist==1.4.1
|
frozenlist==1.4.1
|
||||||
greenlet==3.0.3
|
greenlet==3.0.3
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
|
|
@ -31,17 +32,21 @@ mdurl==0.1.2
|
||||||
minio==7.1.17
|
minio==7.1.17
|
||||||
more-itertools==10.2.0
|
more-itertools==10.2.0
|
||||||
multidict==6.0.5
|
multidict==6.0.5
|
||||||
|
mypy-extensions==1.0.0
|
||||||
mysqlclient==2.2.4
|
mysqlclient==2.2.4
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
orjson==3.10.3
|
orjson==3.10.3
|
||||||
pillow==10.4.0
|
pillow==10.4.0
|
||||||
|
psutil==5.9.8
|
||||||
py-vapid==1.9.1
|
py-vapid==1.9.1
|
||||||
pycparser==2.21
|
pycparser==2.21
|
||||||
pydantic==2.4.2
|
pydantic==2.8.2
|
||||||
pydantic_core==2.10.1
|
pydantic-settings==2.3.4
|
||||||
|
pydantic_core==2.20.1
|
||||||
Pygments==2.18.0
|
Pygments==2.18.0
|
||||||
PyJWT==2.8.0
|
PyJWT==2.8.0
|
||||||
PyMySQL==1.1.0
|
PyMySQL==1.1.0
|
||||||
|
PyNaCl==1.5.0
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
python-multipart==0.0.9
|
python-multipart==0.0.9
|
||||||
pywebpush==2.0.0
|
pywebpush==2.0.0
|
||||||
|
|
@ -59,6 +64,7 @@ sse-starlette==2.1.2
|
||||||
starlette==0.37.2
|
starlette==0.37.2
|
||||||
typeguard==4.3.0
|
typeguard==4.3.0
|
||||||
typer==0.12.3
|
typer==0.12.3
|
||||||
|
typing-inspect==0.9.0
|
||||||
typing_extensions==4.12.2
|
typing_extensions==4.12.2
|
||||||
ujson==5.10.0
|
ujson==5.10.0
|
||||||
urllib3==2.2.1
|
urllib3==2.2.1
|
||||||
|
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
from cef_3M import sql, minioClient
|
|
||||||
|
|
||||||
# This should be run every hour or so to clean up old uploads
|
|
||||||
toDelete = sql.SqlExecuteFetchAll("SELECT *, NOW() FROM uploads WHERE expiry < NOW()")
|
|
||||||
for f in toDelete:
|
|
||||||
minioClient.remove_object("uploads", f["hash"])
|
|
||||||
sql.SqlExecute("DELETE FROM `uploads` WHERE `hash` = %s", f["hash"])
|
|
||||||
print(f"Deleted {len(toDelete)} old files")
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue