From 1b1dcc375589621c47f44774fd091d267a3ca684 Mon Sep 17 00:00:00 2001 From: CEF Server Date: Wed, 17 Jul 2024 14:34:01 +0000 Subject: [PATCH] add push notification support --- .gitignore | 1 + .../f73144466860_alert_encryption_table.py | 34 +++++++++++++ cef_3M/__init__.py | 2 +- cef_3M/endpoints/account.py | 4 ++ cef_3M/endpoints/alerts.py | 51 +++++++++++++++++++ cef_3M/endpoints/mediamtx.py | 6 +-- cef_3M/sql.py | 15 +++++- cef_3M/sql_generated.py | 19 ++++++- cef_3M/util.py | 40 ++++++++++++++- config.example.py | 2 +- main.py | 9 ++++ requirements.txt | 2 + 12 files changed, 175 insertions(+), 10 deletions(-) create mode 100644 alembic/versions/f73144466860_alert_encryption_table.py create mode 100644 cef_3M/endpoints/account.py create mode 100644 cef_3M/endpoints/alerts.py diff --git a/.gitignore b/.gitignore index bb54b34..67c49c3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ config.py alembic.ini +keys/* \ No newline at end of file diff --git a/alembic/versions/f73144466860_alert_encryption_table.py b/alembic/versions/f73144466860_alert_encryption_table.py new file mode 100644 index 0000000..17d66bf --- /dev/null +++ b/alembic/versions/f73144466860_alert_encryption_table.py @@ -0,0 +1,34 @@ +"""alert encryption table + +Revision ID: f73144466860 +Revises: e6b8e42fa629 +Create Date: 2024-07-16 06:55:24.078521 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'f73144466860' +down_revision: Union[str, None] = 'e6b8e42fa629' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "alert_endpoints", + sa.Column("id", sa.INT, autoincrement=True, primary_key=True), + sa.Column("username", sa.VARCHAR(64), index=True), + sa.Column("url", sa.VARCHAR(2048), unique=True), + sa.Column("auth", sa.VARCHAR(2048)), + sa.Column("p256dh", sa.VARCHAR(2048)), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table("alert_endpoints") diff --git a/cef_3M/__init__.py b/cef_3M/__init__.py index e482565..066231b 100644 --- a/cef_3M/__init__.py +++ b/cef_3M/__init__.py @@ -38,7 +38,7 @@ async def upload(file: UploadFile, request: Request): up = Uploads(hash=sha) session.add(up) session.commit() - return {"url": f"https://{config.MINIO_ADDR}/uploads/{sha}/{safeFilename}"} + return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/uploads/{sha}/{safeFilename}"} diff --git a/cef_3M/endpoints/account.py b/cef_3M/endpoints/account.py new file mode 100644 index 0000000..3f146a7 --- /dev/null +++ b/cef_3M/endpoints/account.py @@ -0,0 +1,4 @@ +from . import router +from fastapi import UploadFile, Request, Depends + +from ..auth import JWTBearer diff --git a/cef_3M/endpoints/alerts.py b/cef_3M/endpoints/alerts.py new file mode 100644 index 0000000..804a970 --- /dev/null +++ b/cef_3M/endpoints/alerts.py @@ -0,0 +1,51 @@ +from fastapi import Request, Depends +from starlette.responses import JSONResponse +from . import router +from .. import JWTBearer +from ..sql import SessionMaker, AlertEndpoints +from ..util import redis +from pydantic import BaseModel, create_model, HttpUrl + + +class SubscriptionData(BaseModel): + endpoint: HttpUrl + keys: create_model("keys", auth=(str, ...), p256dh=(str, ...)) + + +@router.post("/alert/register", dependencies=[Depends(JWTBearer())]) +async def register(request: Request, subscription: SubscriptionData): + session = SessionMaker() + check = session.query(AlertEndpoints).filter(AlertEndpoints.url == str(subscription.endpoint)) + if check.first() is not None: + return JSONResponse({ + "success": True + }) + info = AlertEndpoints(username=request.state.jwt["account"], url=str(subscription.endpoint), auth=subscription.keys.auth, p256dh=subscription.keys.p256dh) + + session.merge(info) + session.commit() + + return JSONResponse({ + "success": True + }) + +@router.post("/alert/unregister", dependencies=[Depends(JWTBearer())]) +async def unregister(request: Request): + session = SessionMaker() + body = await request.json() + session.query(AlertEndpoints).filter(AlertEndpoints.url == body.get("url", ""), AlertEndpoints.username == request.state.jwt["account"]).delete() + session.commit() + + return JSONResponse({ + "success": True + }) + +@router.post("/alert/clear", dependencies=[Depends(JWTBearer())]) +async def clear(request: Request): + session = SessionMaker() + session.query(AlertEndpoints).filter(AlertEndpoints.username == request.state.jwt["account"]).delete() + session.commit() + + return JSONResponse({ + "success": True + }) diff --git a/cef_3M/endpoints/mediamtx.py b/cef_3M/endpoints/mediamtx.py index fcaea78..bc2dd05 100644 --- a/cef_3M/endpoints/mediamtx.py +++ b/cef_3M/endpoints/mediamtx.py @@ -23,7 +23,7 @@ def pathParts(path): async def mediamtxChannelStreams(request: Request, channel: str): inChannel = request.state.jwt.get("channel", "").lower() == "#" + channel.lower() results = [] - for result in redis.scan_iter(f"stream #{channel} *"): + async for result in redis.scan_iter(f"stream #{channel} *"): _, channel, user, token = result.decode("utf8").split() if inChannel or token == "public": results.append({ @@ -74,7 +74,7 @@ async def mediamtxAdd(request: Request): body = await request.json() path = body["env"]["MTX_PATH"].split("/") parts = [x for x in pathParts(path) if x] - redis.set("stream " + " ".join(parts), parts[2]) + await redis.set("stream " + " ".join(parts), parts[2]) if len(parts) == 3: await ergo.broadcastTo(parts[0], "STREAMSTART", parts[0], parts[1], parts[2]) @@ -87,7 +87,7 @@ async def mediamtxDelete(request: Request): body = await request.json() path = body["env"]["MTX_PATH"].split("/") parts = [x for x in pathParts(path) if x] - redis.delete("stream " + " ".join(parts)) + await redis.delete("stream " + " ".join(parts)) if len(parts) == 3: await ergo.broadcastTo(parts[0], "STREAMEND", parts[0], parts[1], parts[2]) diff --git a/cef_3M/sql.py b/cef_3M/sql.py index e9ddba8..194e2c5 100644 --- a/cef_3M/sql.py +++ b/cef_3M/sql.py @@ -6,11 +6,22 @@ import configparser alembic = configparser.ConfigParser() alembic.read("alembic.ini") + engine = create_engine( alembic.get("alembic", "sqlalchemy.url"), pool_pre_ping=True, - pool_recycle=3600 + pool_recycle=1800 ) -SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine) +ergoEngine = create_engine( + alembic.get("alembic", "sqlalchemy.url"), + pool_pre_ping=True, + pool_recycle=1800 +) +SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine, ) + +def ergoQueryFetchOne(q: str, **kwargs): + with ergoEngine.connect() as connection: + connection.execute(text("use ergo")) + return connection.execute(text(q), kwargs).fetchone() \ No newline at end of file diff --git a/cef_3M/sql_generated.py b/cef_3M/sql_generated.py index 9a80366..9e92de3 100644 --- a/cef_3M/sql_generated.py +++ b/cef_3M/sql_generated.py @@ -1,7 +1,7 @@ from typing import Optional -from sqlalchemy import CHAR, String, TIMESTAMP, text -from sqlalchemy.dialects.mysql import TINYINT +from sqlalchemy import CHAR, Index, String, TIMESTAMP, text +from sqlalchemy.dialects.mysql import INTEGER, TINYINT from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column import datetime @@ -9,6 +9,21 @@ class Base(DeclarativeBase): pass +class AlertEndpoints(Base): + __tablename__ = 'alert_endpoints' + __table_args__ = ( + Index('ix_alert_endpoints_username', 'username'), + Index('url', 'url', unique=True) + ) + + id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True) + username: Mapped[Optional[str]] = mapped_column(String(64)) + url: Mapped[Optional[str]] = mapped_column(String(2048)) + auth: Mapped[Optional[str]] = mapped_column(String(2048)) + p256dh: Mapped[Optional[str]] = mapped_column(String(2048)) + created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()')) + + class Uploads(Base): __tablename__ = 'uploads' diff --git a/cef_3M/util.py b/cef_3M/util.py index de939a4..608aa1b 100644 --- a/cef_3M/util.py +++ b/cef_3M/util.py @@ -1,9 +1,16 @@ import asyncio import hashlib +import json import re + +import MySQLdb + import config from minio import Minio -from redis import Redis +from redis.asyncio import Redis +from pywebpush import WebPusher + +from .sql import SessionMaker, AlertEndpoints, ergoQueryFetchOne from fastapi import UploadFile @@ -49,7 +56,38 @@ class ErgoClient: @retry async def init(self): self.reader, self.writer = await asyncio.open_connection(config.ERGO_ADDR, config.ERGO_PORT) + await asyncio.get_running_loop().create_task(self.readEvents()) + @retry + async def readEvents(self): + while 1: + rawLine = await self.reader.readline() + if not rawLine: break + line = rawLine.decode("utf8").strip().split() + if line[0] == "MENTION": + await self.handleMention(line[1], line[2], line[3]) + + + async def handleMention(self, username: str, channel: str, msgid: str): + session = SessionMaker() + for target in session.query(AlertEndpoints).filter(AlertEndpoints.username == username): + pusher = WebPusher({ + "endpoint": target.url, + "keys": { + "auth": target.auth, + "p256dh": target.p256dh + } + }) + messageQuery = ergoQueryFetchOne("SELECT `data` FROM `history` WHERE `msgid` = :id", id=int(msgid)).data + message = json.loads(messageQuery.decode("utf8")) + encoded = json.dumps({ + "channel": channel, + "from": message["AccountName"], + "content": message["Message"]["Message"] + }).encode("utf8") + + await pusher.send_async(encoded) + session.close() @retry diff --git a/config.example.py b/config.example.py index 89abc50..70618ee 100644 --- a/config.example.py +++ b/config.example.py @@ -1,5 +1,5 @@ import os -SECRETKEY = os.path.join("secrets", "pubkey.pem") +SECRETKEY = os.path.join("keys", "pubkey.pem") # CEF-specific port (grumble communicates over it as well) ERGO_ADDR = "127.0.0.1" diff --git a/main.py b/main.py index e0f0c77..6cc3470 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,12 @@ +import subprocess + import uvicorn import cef_3M +import os + +if not os.path.exists(os.path.join("keys", "vapid.pem")): + print("Generating VAPID key") + subprocess.run("openssl ecparam -name prime256v1 -genkey -noout -out keys/vapid.pem".split()) + print("key generated") + uvicorn.run("cef_3M:app", port=8001, reload=True) diff --git a/requirements.txt b/requirements.txt index 0490d75..3df5933 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,5 @@ uvicorn==0.23.2 uvloop==0.19.0 watchfiles==0.21.0 websockets==12.0 +redis==5.0.6 +pywebpush==2.0.0 \ No newline at end of file