diff --git a/alembic/versions/173d66d27d78_remember_me_table.py b/alembic/versions/173d66d27d78_remember_me_table.py new file mode 100644 index 0000000..ea2f7f4 --- /dev/null +++ b/alembic/versions/173d66d27d78_remember_me_table.py @@ -0,0 +1,32 @@ +"""remember me table + +Revision ID: 173d66d27d78 +Revises: 3c43e544e939 +Create Date: 2024-09-16 00:04:58.835593 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '173d66d27d78' +down_revision: Union[str, None] = '3c43e544e939' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "sessions", + sa.Column("id", sa.INT, autoincrement=True, primary_key=True), + sa.Column("username", sa.VARCHAR(64), sa.ForeignKey("users.username", ondelete="CASCADE"), index=True), + sa.Column("hash", sa.VARCHAR(128)), + sa.Column("expiry", sa.TIMESTAMP(), nullable=True, default=None), + ) + + +def downgrade() -> None: + op.drop_table("sessions") diff --git a/alembic/versions/aa17ed273170_invited_by_and_check_tokens.py b/alembic/versions/aa17ed273170_invited_by_and_check_tokens.py new file mode 100644 index 0000000..ee96ad8 --- /dev/null +++ b/alembic/versions/aa17ed273170_invited_by_and_check_tokens.py @@ -0,0 +1,39 @@ +"""invited by and check tokens + +Revision ID: aa17ed273170 +Revises: 173d66d27d78 +Create Date: 2024-11-24 03:21:32.324284 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'aa17ed273170' +down_revision: Union[str, None] = '173d66d27d78' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "tokens", + sa.Column("id", sa.INTEGER(), primary_key=True, autoincrement=True), + sa.Column("username", sa.VARCHAR(64), index=True), + sa.Column("hash", sa.VARCHAR(128)), + sa.Column("name", sa.VARCHAR(128)), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.func.now()), + ) + op.create_foreign_key("token_username_fk", "tokens", "users", + ["username"], ["username"], ondelete="CASCADE") + + op.add_column("users", sa.Column("invited_by", sa.VARCHAR(64), default=None)) + + + +def downgrade() -> None: + op.drop_column("users", "invited_by") + op.drop_table("tokens") \ No newline at end of file diff --git a/cef_3M/endpoints/account.py b/cef_3M/endpoints/account.py index 239e872..bd3d432 100644 --- a/cef_3M/endpoints/account.py +++ b/cef_3M/endpoints/account.py @@ -1,17 +1,49 @@ +import base64 import random +from typing import Annotated -from pydantic import BaseModel +from pydantic import BaseModel, StringConstraints from . import router from starlette.responses import JSONResponse from fastapi import Request, HTTPException, Depends -from ..sql import SessionMaker, Users +from ..sql import SessionMaker, Users, Sessions +from ..sql_generated import Tokens from ..util import privilegedIps from ..auth import JWTBearer import nacl.pwhash +import nacl.utils import string +class PasswordChange(BaseModel): + currentPassword: str + newPassword: str + newPasswordAgain: str + + +class RememberMe(BaseModel): + username: str + password: str + +class TokenVerify(BaseModel): + username: str + token: str + +class TokenCreate(BaseModel): + label: Annotated[str, StringConstraints(max_length=64)] + +class TokenDelete(BaseModel): + id: int + +def passwordVerify(hash: str, password: str) -> bool: + try: + nacl.pwhash.scrypt.verify(hash.encode("utf8"), password.encode("utf8")) + return True + except: + return False + + @router.get("/account/exists/{name}") async def exists(name: str): with SessionMaker() as session: @@ -29,10 +61,27 @@ async def exists(name: str): }) -class PasswordChange(BaseModel): - currentPassword: str - newPassword: str - newPasswordAgain: str +@router.post("/account/remember") +async def rememberLogin(request: Request, loginData: RememberMe): + with SessionMaker() as session: + check = session.query(Users).filter(Users.username == str(loginData.username)) + first: Users = check.first() + if not first or not passwordVerify(first.password, loginData.password): + return JSONResponse({ + "success": False, + "error": "Incorrect password" + }) + token = base64.b64encode(nacl.utils.random(32)) + + sess = Sessions(username=first.username, hash=nacl.pwhash.scrypt.str(token)) + session.add(sess) + session.commit() + + return JSONResponse({ + "success": True, + "token": token.decode("utf8") + }) + @router.get("/account/invite", dependencies=[Depends(JWTBearer())]) async def getInvite(request: Request): @@ -44,6 +93,7 @@ async def getInvite(request: Request): "code": user.invite_code }) + @router.post("/account/invite/regenerate", dependencies=[Depends(JWTBearer())]) async def regenInvite(request: Request): username = request.state.jwt["account"] @@ -58,6 +108,7 @@ async def regenInvite(request: Request): "code": code }) + @router.post("/account/password", dependencies=[Depends(JWTBearer(False))]) async def changePassword(request: Request, passwordData: PasswordChange): if passwordData.newPassword != passwordData.newPasswordAgain: @@ -69,22 +120,82 @@ async def changePassword(request: Request, passwordData: PasswordChange): with SessionMaker() as session: user = session.query(Users).filter(Users.username == username).first() - bPassOld = passwordData.currentPassword.encode("utf8") - - try: - nacl.pwhash.scrypt.verify(user.password.encode("utf8"), bPassOld) - except: + if not passwordVerify(user.password, passwordData.currentPassword): raise HTTPException(status_code=403, detail="Invalid original password") bPass = passwordData.newPassword.encode("utf8") user.password = nacl.pwhash.scrypt.str(bPass) user.temporary = False + + # clear sessions and tokens + session.query(Sessions).filter(Sessions.username == user.username).delete() + session.query(Tokens).filter(Tokens.username == user.username).delete() + session.commit() return JSONResponse({ "success": True }) +@router.post("/account/tokenCheck") +async def tokenVerify(request: Request, tokenData: TokenVerify): + with SessionMaker() as session: + tokens = session.query(Tokens).filter(Tokens.username == str(tokenData.username)) + for token in tokens.all(): + if passwordVerify(token.hash, tokenData.token): + return JSONResponse({ + "success": True, + }) + return JSONResponse({ + "success": False, + }) + +@router.get("/account/tokens", dependencies=[Depends(JWTBearer())]) +async def getTokens(request: Request): + username = request.state.jwt["account"] + with SessionMaker() as session: + tokens = session.query(Tokens).filter(Tokens.username == str(username)) + + return JSONResponse({ + "tokens": [ + { + "id": x.id, + "name": x.name, + "created_at": x.created_at.isoformat() + } for x in tokens.all() + ] + }) + +@router.post("/account/token/create", dependencies=[Depends(JWTBearer())]) +async def createToken(request: Request, info: TokenCreate): + username = request.state.jwt["account"] + with SessionMaker() as session: + value = f"{username}-{base64.b64encode(nacl.utils.random(24)).decode('utf8')}" + + token = Tokens(username=username, hash=nacl.pwhash.scrypt.str(value.encode("utf8")), name=info.label) + session.add(token) + session.commit() + + return JSONResponse({ + "token": { + "id": token.id, + "name": token.name, + "created_at": token.created_at.isoformat(), + "token": value + } + }) + +@router.post("/account/token/delete", dependencies=[Depends(JWTBearer())]) +async def deleteToken(request: Request, info: TokenDelete): + username = request.state.jwt["account"] + with SessionMaker() as session: + session.query(Tokens).filter(Tokens.username == username, Tokens.id == info.id).delete() + session.commit() + + return JSONResponse({ + "success": True + }) + @router.post("/account/verify", include_in_schema=False) async def verify(request: Request): if request.client.host not in privilegedIps: @@ -95,12 +206,19 @@ async def verify(request: Request): check = session.query(Users).filter(Users.username == str(body["accountName"])) first = check.first() if first: - try: - nacl.pwhash.scrypt.verify(first.password.encode("utf8"), bPass) + # Not too happy with this approach but it should work + sessions = session.query(Sessions).filter(Users.username == str(body["accountName"])) + for sess in sessions.all(): + if passwordVerify(sess.hash, body.get("passphrase", "")): + return JSONResponse({ + "success": True, + }) + + if passwordVerify(first.password, body.get("passphrase", "")): return JSONResponse({ "success": True, }) - except: + else: return JSONResponse({ "success": False, "error": "Incorrect password" @@ -123,7 +241,6 @@ async def verify(request: Request): "success": False, "error": "Bad invite code" }) - print("invite code", code, "password", password) account = Users(username=body["accountName"], password=nacl.pwhash.scrypt.str(password), temporary=True) session.add(account) session.commit() diff --git a/cef_3M/sql_generated.py b/cef_3M/sql_generated.py index 7976f43..4380306 100644 --- a/cef_3M/sql_generated.py +++ b/cef_3M/sql_generated.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import List, Optional -from sqlalchemy import CHAR, Index, String, TIMESTAMP, text +from sqlalchemy import CHAR, ForeignKeyConstraint, Index, String, TIMESTAMP, text from sqlalchemy.dialects.mysql import INTEGER, TINYINT -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship import datetime class Base(DeclarativeBase): @@ -39,3 +39,38 @@ class Users(Base): created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()')) temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1')) invite_code: Mapped[Optional[str]] = mapped_column(String(32)) + invited_by: Mapped[Optional[str]] = mapped_column(String(64)) + + sessions: Mapped[List['Sessions']] = relationship('Sessions', back_populates='users') + tokens: Mapped[List['Tokens']] = relationship('Tokens', back_populates='users') + + +class Sessions(Base): + __tablename__ = 'sessions' + __table_args__ = ( + ForeignKeyConstraint(['username'], ['users.username'], ondelete='CASCADE', name='sessions_ibfk_1'), + Index('ix_sessions_username', 'username') + ) + + id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True) + username: Mapped[Optional[str]] = mapped_column(String(64)) + hash: Mapped[Optional[str]] = mapped_column(String(128)) + expiry: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP) + + users: Mapped['Users'] = relationship('Users', back_populates='sessions') + + +class Tokens(Base): + __tablename__ = 'tokens' + __table_args__ = ( + ForeignKeyConstraint(['username'], ['users.username'], ondelete='CASCADE', name='token_username_fk'), + Index('ix_tokens_username', 'username') + ) + + id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True) + username: Mapped[Optional[str]] = mapped_column(String(64)) + hash: Mapped[Optional[str]] = mapped_column(String(128)) + name: Mapped[Optional[str]] = mapped_column(String(128)) + created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()')) + + users: Mapped['Users'] = relationship('Users', back_populates='tokens') diff --git a/run.sh b/run.sh index f421a00..3ec0310 100755 --- a/run.sh +++ b/run.sh @@ -1,3 +1,3 @@ #!/bin/sh alembic upgrade head -uvicorn cef_3M:app --port 8001 --host 0.0.0.0 +uvicorn cef_3M:app --port 8001 --host 0.0.0.0 --reload