add push notification support
This commit is contained in:
parent
6a69c5a34d
commit
1b1dcc3755
12 changed files with 175 additions and 10 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -1,2 +1,3 @@
|
||||||
config.py
|
config.py
|
||||||
alembic.ini
|
alembic.ini
|
||||||
|
keys/*
|
||||||
34
alembic/versions/f73144466860_alert_encryption_table.py
Normal file
34
alembic/versions/f73144466860_alert_encryption_table.py
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
"""alert encryption table
|
||||||
|
|
||||||
|
Revision ID: f73144466860
|
||||||
|
Revises: e6b8e42fa629
|
||||||
|
Create Date: 2024-07-16 06:55:24.078521
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'f73144466860'
|
||||||
|
down_revision: Union[str, None] = 'e6b8e42fa629'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"alert_endpoints",
|
||||||
|
sa.Column("id", sa.INT, autoincrement=True, primary_key=True),
|
||||||
|
sa.Column("username", sa.VARCHAR(64), index=True),
|
||||||
|
sa.Column("url", sa.VARCHAR(2048), unique=True),
|
||||||
|
sa.Column("auth", sa.VARCHAR(2048)),
|
||||||
|
sa.Column("p256dh", sa.VARCHAR(2048)),
|
||||||
|
sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("alert_endpoints")
|
||||||
|
|
@ -38,7 +38,7 @@ async def upload(file: UploadFile, request: Request):
|
||||||
up = Uploads(hash=sha)
|
up = Uploads(hash=sha)
|
||||||
session.add(up)
|
session.add(up)
|
||||||
session.commit()
|
session.commit()
|
||||||
return {"url": f"https://{config.MINIO_ADDR}/uploads/{sha}/{safeFilename}"}
|
return {"url": f"https://{config.MINIO_EXTERNAL_ADDR}/uploads/{sha}/{safeFilename}"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
4
cef_3M/endpoints/account.py
Normal file
4
cef_3M/endpoints/account.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
from . import router
|
||||||
|
from fastapi import UploadFile, Request, Depends
|
||||||
|
|
||||||
|
from ..auth import JWTBearer
|
||||||
51
cef_3M/endpoints/alerts.py
Normal file
51
cef_3M/endpoints/alerts.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
from fastapi import Request, Depends
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from . import router
|
||||||
|
from .. import JWTBearer
|
||||||
|
from ..sql import SessionMaker, AlertEndpoints
|
||||||
|
from ..util import redis
|
||||||
|
from pydantic import BaseModel, create_model, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionData(BaseModel):
|
||||||
|
endpoint: HttpUrl
|
||||||
|
keys: create_model("keys", auth=(str, ...), p256dh=(str, ...))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/alert/register", dependencies=[Depends(JWTBearer())])
|
||||||
|
async def register(request: Request, subscription: SubscriptionData):
|
||||||
|
session = SessionMaker()
|
||||||
|
check = session.query(AlertEndpoints).filter(AlertEndpoints.url == str(subscription.endpoint))
|
||||||
|
if check.first() is not None:
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
info = AlertEndpoints(username=request.state.jwt["account"], url=str(subscription.endpoint), auth=subscription.keys.auth, p256dh=subscription.keys.p256dh)
|
||||||
|
|
||||||
|
session.merge(info)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
|
||||||
|
@router.post("/alert/unregister", dependencies=[Depends(JWTBearer())])
|
||||||
|
async def unregister(request: Request):
|
||||||
|
session = SessionMaker()
|
||||||
|
body = await request.json()
|
||||||
|
session.query(AlertEndpoints).filter(AlertEndpoints.url == body.get("url", ""), AlertEndpoints.username == request.state.jwt["account"]).delete()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
|
||||||
|
@router.post("/alert/clear", dependencies=[Depends(JWTBearer())])
|
||||||
|
async def clear(request: Request):
|
||||||
|
session = SessionMaker()
|
||||||
|
session.query(AlertEndpoints).filter(AlertEndpoints.username == request.state.jwt["account"]).delete()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return JSONResponse({
|
||||||
|
"success": True
|
||||||
|
})
|
||||||
|
|
@ -23,7 +23,7 @@ def pathParts(path):
|
||||||
async def mediamtxChannelStreams(request: Request, channel: str):
|
async def mediamtxChannelStreams(request: Request, channel: str):
|
||||||
inChannel = request.state.jwt.get("channel", "").lower() == "#" + channel.lower()
|
inChannel = request.state.jwt.get("channel", "").lower() == "#" + channel.lower()
|
||||||
results = []
|
results = []
|
||||||
for result in redis.scan_iter(f"stream #{channel} *"):
|
async for result in redis.scan_iter(f"stream #{channel} *"):
|
||||||
_, channel, user, token = result.decode("utf8").split()
|
_, channel, user, token = result.decode("utf8").split()
|
||||||
if inChannel or token == "public":
|
if inChannel or token == "public":
|
||||||
results.append({
|
results.append({
|
||||||
|
|
@ -74,7 +74,7 @@ async def mediamtxAdd(request: Request):
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
path = body["env"]["MTX_PATH"].split("/")
|
path = body["env"]["MTX_PATH"].split("/")
|
||||||
parts = [x for x in pathParts(path) if x]
|
parts = [x for x in pathParts(path) if x]
|
||||||
redis.set("stream " + " ".join(parts), parts[2])
|
await redis.set("stream " + " ".join(parts), parts[2])
|
||||||
if len(parts) == 3:
|
if len(parts) == 3:
|
||||||
await ergo.broadcastTo(parts[0], "STREAMSTART", parts[0], parts[1], parts[2])
|
await ergo.broadcastTo(parts[0], "STREAMSTART", parts[0], parts[1], parts[2])
|
||||||
|
|
||||||
|
|
@ -87,7 +87,7 @@ async def mediamtxDelete(request: Request):
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
path = body["env"]["MTX_PATH"].split("/")
|
path = body["env"]["MTX_PATH"].split("/")
|
||||||
parts = [x for x in pathParts(path) if x]
|
parts = [x for x in pathParts(path) if x]
|
||||||
redis.delete("stream " + " ".join(parts))
|
await redis.delete("stream " + " ".join(parts))
|
||||||
if len(parts) == 3:
|
if len(parts) == 3:
|
||||||
await ergo.broadcastTo(parts[0], "STREAMEND", parts[0], parts[1], parts[2])
|
await ergo.broadcastTo(parts[0], "STREAMEND", parts[0], parts[1], parts[2])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,22 @@ import configparser
|
||||||
alembic = configparser.ConfigParser()
|
alembic = configparser.ConfigParser()
|
||||||
alembic.read("alembic.ini")
|
alembic.read("alembic.ini")
|
||||||
|
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
alembic.get("alembic", "sqlalchemy.url"),
|
alembic.get("alembic", "sqlalchemy.url"),
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
pool_recycle=3600
|
pool_recycle=1800
|
||||||
)
|
)
|
||||||
|
|
||||||
SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
ergoEngine = create_engine(
|
||||||
|
alembic.get("alembic", "sqlalchemy.url"),
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_recycle=1800
|
||||||
|
)
|
||||||
|
|
||||||
|
SessionMaker = sessionmaker(autocommit=False, autoflush=False, bind=engine, )
|
||||||
|
|
||||||
|
def ergoQueryFetchOne(q: str, **kwargs):
|
||||||
|
with ergoEngine.connect() as connection:
|
||||||
|
connection.execute(text("use ergo"))
|
||||||
|
return connection.execute(text(q), kwargs).fetchone()
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import CHAR, String, TIMESTAMP, text
|
from sqlalchemy import CHAR, Index, String, TIMESTAMP, text
|
||||||
from sqlalchemy.dialects.mysql import TINYINT
|
from sqlalchemy.dialects.mysql import INTEGER, TINYINT
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
@ -9,6 +9,21 @@ class Base(DeclarativeBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AlertEndpoints(Base):
|
||||||
|
__tablename__ = 'alert_endpoints'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ix_alert_endpoints_username', 'username'),
|
||||||
|
Index('url', 'url', unique=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True)
|
||||||
|
username: Mapped[Optional[str]] = mapped_column(String(64))
|
||||||
|
url: Mapped[Optional[str]] = mapped_column(String(2048))
|
||||||
|
auth: Mapped[Optional[str]] = mapped_column(String(2048))
|
||||||
|
p256dh: Mapped[Optional[str]] = mapped_column(String(2048))
|
||||||
|
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
|
||||||
|
|
||||||
|
|
||||||
class Uploads(Base):
|
class Uploads(Base):
|
||||||
__tablename__ = 'uploads'
|
__tablename__ = 'uploads'
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,16 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import MySQLdb
|
||||||
|
|
||||||
import config
|
import config
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
from redis import Redis
|
from redis.asyncio import Redis
|
||||||
|
from pywebpush import WebPusher
|
||||||
|
|
||||||
|
from .sql import SessionMaker, AlertEndpoints, ergoQueryFetchOne
|
||||||
|
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
|
@ -49,7 +56,38 @@ class ErgoClient:
|
||||||
@retry
|
@retry
|
||||||
async def init(self):
|
async def init(self):
|
||||||
self.reader, self.writer = await asyncio.open_connection(config.ERGO_ADDR, config.ERGO_PORT)
|
self.reader, self.writer = await asyncio.open_connection(config.ERGO_ADDR, config.ERGO_PORT)
|
||||||
|
await asyncio.get_running_loop().create_task(self.readEvents())
|
||||||
|
|
||||||
|
@retry
|
||||||
|
async def readEvents(self):
|
||||||
|
while 1:
|
||||||
|
rawLine = await self.reader.readline()
|
||||||
|
if not rawLine: break
|
||||||
|
line = rawLine.decode("utf8").strip().split()
|
||||||
|
if line[0] == "MENTION":
|
||||||
|
await self.handleMention(line[1], line[2], line[3])
|
||||||
|
|
||||||
|
|
||||||
|
async def handleMention(self, username: str, channel: str, msgid: str):
|
||||||
|
session = SessionMaker()
|
||||||
|
for target in session.query(AlertEndpoints).filter(AlertEndpoints.username == username):
|
||||||
|
pusher = WebPusher({
|
||||||
|
"endpoint": target.url,
|
||||||
|
"keys": {
|
||||||
|
"auth": target.auth,
|
||||||
|
"p256dh": target.p256dh
|
||||||
|
}
|
||||||
|
})
|
||||||
|
messageQuery = ergoQueryFetchOne("SELECT `data` FROM `history` WHERE `msgid` = :id", id=int(msgid)).data
|
||||||
|
message = json.loads(messageQuery.decode("utf8"))
|
||||||
|
encoded = json.dumps({
|
||||||
|
"channel": channel,
|
||||||
|
"from": message["AccountName"],
|
||||||
|
"content": message["Message"]["Message"]
|
||||||
|
}).encode("utf8")
|
||||||
|
|
||||||
|
await pusher.send_async(encoded)
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
@retry
|
@retry
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
SECRETKEY = os.path.join("secrets", "pubkey.pem")
|
SECRETKEY = os.path.join("keys", "pubkey.pem")
|
||||||
|
|
||||||
# CEF-specific port (grumble communicates over it as well)
|
# CEF-specific port (grumble communicates over it as well)
|
||||||
ERGO_ADDR = "127.0.0.1"
|
ERGO_ADDR = "127.0.0.1"
|
||||||
|
|
|
||||||
9
main.py
9
main.py
|
|
@ -1,3 +1,12 @@
|
||||||
|
import subprocess
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import cef_3M
|
import cef_3M
|
||||||
|
import os
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join("keys", "vapid.pem")):
|
||||||
|
print("Generating VAPID key")
|
||||||
|
subprocess.run("openssl ecparam -name prime256v1 -genkey -noout -out keys/vapid.pem".split())
|
||||||
|
print("key generated")
|
||||||
|
|
||||||
uvicorn.run("cef_3M:app", port=8001, reload=True)
|
uvicorn.run("cef_3M:app", port=8001, reload=True)
|
||||||
|
|
|
||||||
|
|
@ -52,3 +52,5 @@ uvicorn==0.23.2
|
||||||
uvloop==0.19.0
|
uvloop==0.19.0
|
||||||
watchfiles==0.21.0
|
watchfiles==0.21.0
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
|
redis==5.0.6
|
||||||
|
pywebpush==2.0.0
|
||||||
Loading…
Add table
Add a link
Reference in a new issue