remembering logins + tokens

This commit is contained in:
CEF Server 2024-11-28 02:54:10 +00:00
parent 83404812d0
commit 6871aa2449
5 changed files with 242 additions and 19 deletions

View file

@ -0,0 +1,32 @@
"""remember me table
Revision ID: 173d66d27d78
Revises: 3c43e544e939
Create Date: 2024-09-16 00:04:58.835593
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '173d66d27d78'
down_revision: Union[str, None] = '3c43e544e939'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"sessions",
sa.Column("id", sa.INT, autoincrement=True, primary_key=True),
sa.Column("username", sa.VARCHAR(64), sa.ForeignKey("users.username", ondelete="CASCADE"), index=True),
sa.Column("hash", sa.VARCHAR(128)),
sa.Column("expiry", sa.TIMESTAMP(), nullable=True, default=None),
)
def downgrade() -> None:
op.drop_table("sessions")

View file

@ -0,0 +1,39 @@
"""invited by and check tokens
Revision ID: aa17ed273170
Revises: 173d66d27d78
Create Date: 2024-11-24 03:21:32.324284
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'aa17ed273170'
down_revision: Union[str, None] = '173d66d27d78'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"tokens",
sa.Column("id", sa.INTEGER(), primary_key=True, autoincrement=True),
sa.Column("username", sa.VARCHAR(64), index=True),
sa.Column("hash", sa.VARCHAR(128)),
sa.Column("name", sa.VARCHAR(128)),
sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.func.now()),
)
op.create_foreign_key("token_username_fk", "tokens", "users",
["username"], ["username"], ondelete="CASCADE")
op.add_column("users", sa.Column("invited_by", sa.VARCHAR(64), default=None))
def downgrade() -> None:
op.drop_column("users", "invited_by")
op.drop_table("tokens")

View file

@ -1,17 +1,49 @@
import base64
import random
from typing import Annotated
from pydantic import BaseModel
from pydantic import BaseModel, StringConstraints
from . import router
from starlette.responses import JSONResponse
from fastapi import Request, HTTPException, Depends
from ..sql import SessionMaker, Users
from ..sql import SessionMaker, Users, Sessions
from ..sql_generated import Tokens
from ..util import privilegedIps
from ..auth import JWTBearer
import nacl.pwhash
import nacl.utils
import string
class PasswordChange(BaseModel):
currentPassword: str
newPassword: str
newPasswordAgain: str
class RememberMe(BaseModel):
username: str
password: str
class TokenVerify(BaseModel):
username: str
token: str
class TokenCreate(BaseModel):
label: Annotated[str, StringConstraints(max_length=64)]
class TokenDelete(BaseModel):
id: int
def passwordVerify(hash: str, password: str) -> bool:
try:
nacl.pwhash.scrypt.verify(hash.encode("utf8"), password.encode("utf8"))
return True
except:
return False
@router.get("/account/exists/{name}")
async def exists(name: str):
with SessionMaker() as session:
@ -29,10 +61,27 @@ async def exists(name: str):
})
class PasswordChange(BaseModel):
currentPassword: str
newPassword: str
newPasswordAgain: str
@router.post("/account/remember")
async def rememberLogin(request: Request, loginData: RememberMe):
with SessionMaker() as session:
check = session.query(Users).filter(Users.username == str(loginData.username))
first: Users = check.first()
if not first or not passwordVerify(first.password, loginData.password):
return JSONResponse({
"success": False,
"error": "Incorrect password"
})
token = base64.b64encode(nacl.utils.random(32))
sess = Sessions(username=first.username, hash=nacl.pwhash.scrypt.str(token))
session.add(sess)
session.commit()
return JSONResponse({
"success": True,
"token": token.decode("utf8")
})
@router.get("/account/invite", dependencies=[Depends(JWTBearer())])
async def getInvite(request: Request):
@ -44,6 +93,7 @@ async def getInvite(request: Request):
"code": user.invite_code
})
@router.post("/account/invite/regenerate", dependencies=[Depends(JWTBearer())])
async def regenInvite(request: Request):
username = request.state.jwt["account"]
@ -58,6 +108,7 @@ async def regenInvite(request: Request):
"code": code
})
@router.post("/account/password", dependencies=[Depends(JWTBearer(False))])
async def changePassword(request: Request, passwordData: PasswordChange):
if passwordData.newPassword != passwordData.newPasswordAgain:
@ -69,22 +120,82 @@ async def changePassword(request: Request, passwordData: PasswordChange):
with SessionMaker() as session:
user = session.query(Users).filter(Users.username == username).first()
bPassOld = passwordData.currentPassword.encode("utf8")
try:
nacl.pwhash.scrypt.verify(user.password.encode("utf8"), bPassOld)
except:
if not passwordVerify(user.password, passwordData.currentPassword):
raise HTTPException(status_code=403, detail="Invalid original password")
bPass = passwordData.newPassword.encode("utf8")
user.password = nacl.pwhash.scrypt.str(bPass)
user.temporary = False
# clear sessions and tokens
session.query(Sessions).filter(Sessions.username == user.username).delete()
session.query(Tokens).filter(Tokens.username == user.username).delete()
session.commit()
return JSONResponse({
"success": True
})
@router.post("/account/tokenCheck")
async def tokenVerify(request: Request, tokenData: TokenVerify):
with SessionMaker() as session:
tokens = session.query(Tokens).filter(Tokens.username == str(tokenData.username))
for token in tokens.all():
if passwordVerify(token.hash, tokenData.token):
return JSONResponse({
"success": True,
})
return JSONResponse({
"success": False,
})
@router.get("/account/tokens", dependencies=[Depends(JWTBearer())])
async def getTokens(request: Request):
username = request.state.jwt["account"]
with SessionMaker() as session:
tokens = session.query(Tokens).filter(Tokens.username == str(username))
return JSONResponse({
"tokens": [
{
"id": x.id,
"name": x.name,
"created_at": x.created_at.isoformat()
} for x in tokens.all()
]
})
@router.post("/account/token/create", dependencies=[Depends(JWTBearer())])
async def createToken(request: Request, info: TokenCreate):
username = request.state.jwt["account"]
with SessionMaker() as session:
value = f"{username}-{base64.b64encode(nacl.utils.random(24)).decode('utf8')}"
token = Tokens(username=username, hash=nacl.pwhash.scrypt.str(value.encode("utf8")), name=info.label)
session.add(token)
session.commit()
return JSONResponse({
"token": {
"id": token.id,
"name": token.name,
"created_at": token.created_at.isoformat(),
"token": value
}
})
@router.post("/account/token/delete", dependencies=[Depends(JWTBearer())])
async def deleteToken(request: Request, info: TokenDelete):
username = request.state.jwt["account"]
with SessionMaker() as session:
session.query(Tokens).filter(Tokens.username == username, Tokens.id == info.id).delete()
session.commit()
return JSONResponse({
"success": True
})
@router.post("/account/verify", include_in_schema=False)
async def verify(request: Request):
if request.client.host not in privilegedIps:
@ -95,12 +206,19 @@ async def verify(request: Request):
check = session.query(Users).filter(Users.username == str(body["accountName"]))
first = check.first()
if first:
try:
nacl.pwhash.scrypt.verify(first.password.encode("utf8"), bPass)
# Not too happy with this approach but it should work
sessions = session.query(Sessions).filter(Users.username == str(body["accountName"]))
for sess in sessions.all():
if passwordVerify(sess.hash, body.get("passphrase", "")):
return JSONResponse({
"success": True,
})
except:
if passwordVerify(first.password, body.get("passphrase", "")):
return JSONResponse({
"success": True,
})
else:
return JSONResponse({
"success": False,
"error": "Incorrect password"
@ -123,7 +241,6 @@ async def verify(request: Request):
"success": False,
"error": "Bad invite code"
})
print("invite code", code, "password", password)
account = Users(username=body["accountName"], password=nacl.pwhash.scrypt.str(password), temporary=True)
session.add(account)
session.commit()

View file

@ -1,8 +1,8 @@
from typing import Optional
from typing import List, Optional
from sqlalchemy import CHAR, Index, String, TIMESTAMP, text
from sqlalchemy import CHAR, ForeignKeyConstraint, Index, String, TIMESTAMP, text
from sqlalchemy.dialects.mysql import INTEGER, TINYINT
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
import datetime
class Base(DeclarativeBase):
@ -39,3 +39,38 @@ class Users(Base):
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1'))
invite_code: Mapped[Optional[str]] = mapped_column(String(32))
invited_by: Mapped[Optional[str]] = mapped_column(String(64))
sessions: Mapped[List['Sessions']] = relationship('Sessions', back_populates='users')
tokens: Mapped[List['Tokens']] = relationship('Tokens', back_populates='users')
class Sessions(Base):
__tablename__ = 'sessions'
__table_args__ = (
ForeignKeyConstraint(['username'], ['users.username'], ondelete='CASCADE', name='sessions_ibfk_1'),
Index('ix_sessions_username', 'username')
)
id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True)
username: Mapped[Optional[str]] = mapped_column(String(64))
hash: Mapped[Optional[str]] = mapped_column(String(128))
expiry: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP)
users: Mapped['Users'] = relationship('Users', back_populates='sessions')
class Tokens(Base):
__tablename__ = 'tokens'
__table_args__ = (
ForeignKeyConstraint(['username'], ['users.username'], ondelete='CASCADE', name='token_username_fk'),
Index('ix_tokens_username', 'username')
)
id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True)
username: Mapped[Optional[str]] = mapped_column(String(64))
hash: Mapped[Optional[str]] = mapped_column(String(128))
name: Mapped[Optional[str]] = mapped_column(String(128))
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
users: Mapped['Users'] = relationship('Users', back_populates='tokens')

2
run.sh
View file

@ -1,3 +1,3 @@
#!/bin/sh
alembic upgrade head
uvicorn cef_3M:app --port 8001 --host 0.0.0.0
uvicorn cef_3M:app --port 8001 --host 0.0.0.0 --reload