remembering logins + tokens

This commit is contained in:
CEF Server 2024-11-28 02:54:10 +00:00
parent 83404812d0
commit 6871aa2449
5 changed files with 242 additions and 19 deletions

View file

@ -0,0 +1,32 @@
"""remember me table
Revision ID: 173d66d27d78
Revises: 3c43e544e939
Create Date: 2024-09-16 00:04:58.835593
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '173d66d27d78'
down_revision: Union[str, None] = '3c43e544e939'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"sessions",
sa.Column("id", sa.INT, autoincrement=True, primary_key=True),
sa.Column("username", sa.VARCHAR(64), sa.ForeignKey("users.username", ondelete="CASCADE"), index=True),
sa.Column("hash", sa.VARCHAR(128)),
sa.Column("expiry", sa.TIMESTAMP(), nullable=True, default=None),
)
def downgrade() -> None:
op.drop_table("sessions")

View file

@ -0,0 +1,39 @@
"""invited by and check tokens
Revision ID: aa17ed273170
Revises: 173d66d27d78
Create Date: 2024-11-24 03:21:32.324284
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'aa17ed273170'
down_revision: Union[str, None] = '173d66d27d78'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"tokens",
sa.Column("id", sa.INTEGER(), primary_key=True, autoincrement=True),
sa.Column("username", sa.VARCHAR(64), index=True),
sa.Column("hash", sa.VARCHAR(128)),
sa.Column("name", sa.VARCHAR(128)),
sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.func.now()),
)
op.create_foreign_key("token_username_fk", "tokens", "users",
["username"], ["username"], ondelete="CASCADE")
op.add_column("users", sa.Column("invited_by", sa.VARCHAR(64), default=None))
def downgrade() -> None:
op.drop_column("users", "invited_by")
op.drop_table("tokens")

View file

@ -1,17 +1,49 @@
import base64
import random import random
from typing import Annotated
from pydantic import BaseModel from pydantic import BaseModel, StringConstraints
from . import router from . import router
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from fastapi import Request, HTTPException, Depends from fastapi import Request, HTTPException, Depends
from ..sql import SessionMaker, Users from ..sql import SessionMaker, Users, Sessions
from ..sql_generated import Tokens
from ..util import privilegedIps from ..util import privilegedIps
from ..auth import JWTBearer from ..auth import JWTBearer
import nacl.pwhash import nacl.pwhash
import nacl.utils
import string import string
class PasswordChange(BaseModel):
currentPassword: str
newPassword: str
newPasswordAgain: str
class RememberMe(BaseModel):
username: str
password: str
class TokenVerify(BaseModel):
username: str
token: str
class TokenCreate(BaseModel):
label: Annotated[str, StringConstraints(max_length=64)]
class TokenDelete(BaseModel):
id: int
def passwordVerify(hash: str, password: str) -> bool:
try:
nacl.pwhash.scrypt.verify(hash.encode("utf8"), password.encode("utf8"))
return True
except:
return False
@router.get("/account/exists/{name}") @router.get("/account/exists/{name}")
async def exists(name: str): async def exists(name: str):
with SessionMaker() as session: with SessionMaker() as session:
@ -29,10 +61,27 @@ async def exists(name: str):
}) })
class PasswordChange(BaseModel): @router.post("/account/remember")
currentPassword: str async def rememberLogin(request: Request, loginData: RememberMe):
newPassword: str with SessionMaker() as session:
newPasswordAgain: str check = session.query(Users).filter(Users.username == str(loginData.username))
first: Users = check.first()
if not first or not passwordVerify(first.password, loginData.password):
return JSONResponse({
"success": False,
"error": "Incorrect password"
})
token = base64.b64encode(nacl.utils.random(32))
sess = Sessions(username=first.username, hash=nacl.pwhash.scrypt.str(token))
session.add(sess)
session.commit()
return JSONResponse({
"success": True,
"token": token.decode("utf8")
})
@router.get("/account/invite", dependencies=[Depends(JWTBearer())]) @router.get("/account/invite", dependencies=[Depends(JWTBearer())])
async def getInvite(request: Request): async def getInvite(request: Request):
@ -44,6 +93,7 @@ async def getInvite(request: Request):
"code": user.invite_code "code": user.invite_code
}) })
@router.post("/account/invite/regenerate", dependencies=[Depends(JWTBearer())]) @router.post("/account/invite/regenerate", dependencies=[Depends(JWTBearer())])
async def regenInvite(request: Request): async def regenInvite(request: Request):
username = request.state.jwt["account"] username = request.state.jwt["account"]
@ -58,6 +108,7 @@ async def regenInvite(request: Request):
"code": code "code": code
}) })
@router.post("/account/password", dependencies=[Depends(JWTBearer(False))]) @router.post("/account/password", dependencies=[Depends(JWTBearer(False))])
async def changePassword(request: Request, passwordData: PasswordChange): async def changePassword(request: Request, passwordData: PasswordChange):
if passwordData.newPassword != passwordData.newPasswordAgain: if passwordData.newPassword != passwordData.newPasswordAgain:
@ -69,22 +120,82 @@ async def changePassword(request: Request, passwordData: PasswordChange):
with SessionMaker() as session: with SessionMaker() as session:
user = session.query(Users).filter(Users.username == username).first() user = session.query(Users).filter(Users.username == username).first()
bPassOld = passwordData.currentPassword.encode("utf8") if not passwordVerify(user.password, passwordData.currentPassword):
try:
nacl.pwhash.scrypt.verify(user.password.encode("utf8"), bPassOld)
except:
raise HTTPException(status_code=403, detail="Invalid original password") raise HTTPException(status_code=403, detail="Invalid original password")
bPass = passwordData.newPassword.encode("utf8") bPass = passwordData.newPassword.encode("utf8")
user.password = nacl.pwhash.scrypt.str(bPass) user.password = nacl.pwhash.scrypt.str(bPass)
user.temporary = False user.temporary = False
# clear sessions and tokens
session.query(Sessions).filter(Sessions.username == user.username).delete()
session.query(Tokens).filter(Tokens.username == user.username).delete()
session.commit() session.commit()
return JSONResponse({ return JSONResponse({
"success": True "success": True
}) })
@router.post("/account/tokenCheck")
async def tokenVerify(request: Request, tokenData: TokenVerify):
with SessionMaker() as session:
tokens = session.query(Tokens).filter(Tokens.username == str(tokenData.username))
for token in tokens.all():
if passwordVerify(token.hash, tokenData.token):
return JSONResponse({
"success": True,
})
return JSONResponse({
"success": False,
})
@router.get("/account/tokens", dependencies=[Depends(JWTBearer())])
async def getTokens(request: Request):
username = request.state.jwt["account"]
with SessionMaker() as session:
tokens = session.query(Tokens).filter(Tokens.username == str(username))
return JSONResponse({
"tokens": [
{
"id": x.id,
"name": x.name,
"created_at": x.created_at.isoformat()
} for x in tokens.all()
]
})
@router.post("/account/token/create", dependencies=[Depends(JWTBearer())])
async def createToken(request: Request, info: TokenCreate):
username = request.state.jwt["account"]
with SessionMaker() as session:
value = f"{username}-{base64.b64encode(nacl.utils.random(24)).decode('utf8')}"
token = Tokens(username=username, hash=nacl.pwhash.scrypt.str(value.encode("utf8")), name=info.label)
session.add(token)
session.commit()
return JSONResponse({
"token": {
"id": token.id,
"name": token.name,
"created_at": token.created_at.isoformat(),
"token": value
}
})
@router.post("/account/token/delete", dependencies=[Depends(JWTBearer())])
async def deleteToken(request: Request, info: TokenDelete):
username = request.state.jwt["account"]
with SessionMaker() as session:
session.query(Tokens).filter(Tokens.username == username, Tokens.id == info.id).delete()
session.commit()
return JSONResponse({
"success": True
})
@router.post("/account/verify", include_in_schema=False) @router.post("/account/verify", include_in_schema=False)
async def verify(request: Request): async def verify(request: Request):
if request.client.host not in privilegedIps: if request.client.host not in privilegedIps:
@ -95,12 +206,19 @@ async def verify(request: Request):
check = session.query(Users).filter(Users.username == str(body["accountName"])) check = session.query(Users).filter(Users.username == str(body["accountName"]))
first = check.first() first = check.first()
if first: if first:
try: # Not too happy with this approach but it should work
nacl.pwhash.scrypt.verify(first.password.encode("utf8"), bPass) sessions = session.query(Sessions).filter(Users.username == str(body["accountName"]))
for sess in sessions.all():
if passwordVerify(sess.hash, body.get("passphrase", "")):
return JSONResponse({
"success": True,
})
if passwordVerify(first.password, body.get("passphrase", "")):
return JSONResponse({ return JSONResponse({
"success": True, "success": True,
}) })
except: else:
return JSONResponse({ return JSONResponse({
"success": False, "success": False,
"error": "Incorrect password" "error": "Incorrect password"
@ -123,7 +241,6 @@ async def verify(request: Request):
"success": False, "success": False,
"error": "Bad invite code" "error": "Bad invite code"
}) })
print("invite code", code, "password", password)
account = Users(username=body["accountName"], password=nacl.pwhash.scrypt.str(password), temporary=True) account = Users(username=body["accountName"], password=nacl.pwhash.scrypt.str(password), temporary=True)
session.add(account) session.add(account)
session.commit() session.commit()

View file

@ -1,8 +1,8 @@
from typing import Optional from typing import List, Optional
from sqlalchemy import CHAR, Index, String, TIMESTAMP, text from sqlalchemy import CHAR, ForeignKeyConstraint, Index, String, TIMESTAMP, text
from sqlalchemy.dialects.mysql import INTEGER, TINYINT from sqlalchemy.dialects.mysql import INTEGER, TINYINT
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
import datetime import datetime
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -39,3 +39,38 @@ class Users(Base):
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()')) created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1')) temporary: Mapped[Optional[int]] = mapped_column(TINYINT(1), server_default=text('1'))
invite_code: Mapped[Optional[str]] = mapped_column(String(32)) invite_code: Mapped[Optional[str]] = mapped_column(String(32))
invited_by: Mapped[Optional[str]] = mapped_column(String(64))
sessions: Mapped[List['Sessions']] = relationship('Sessions', back_populates='users')
tokens: Mapped[List['Tokens']] = relationship('Tokens', back_populates='users')
class Sessions(Base):
__tablename__ = 'sessions'
__table_args__ = (
ForeignKeyConstraint(['username'], ['users.username'], ondelete='CASCADE', name='sessions_ibfk_1'),
Index('ix_sessions_username', 'username')
)
id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True)
username: Mapped[Optional[str]] = mapped_column(String(64))
hash: Mapped[Optional[str]] = mapped_column(String(128))
expiry: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP)
users: Mapped['Users'] = relationship('Users', back_populates='sessions')
class Tokens(Base):
__tablename__ = 'tokens'
__table_args__ = (
ForeignKeyConstraint(['username'], ['users.username'], ondelete='CASCADE', name='token_username_fk'),
Index('ix_tokens_username', 'username')
)
id: Mapped[int] = mapped_column(INTEGER(11), primary_key=True)
username: Mapped[Optional[str]] = mapped_column(String(64))
hash: Mapped[Optional[str]] = mapped_column(String(128))
name: Mapped[Optional[str]] = mapped_column(String(128))
created_at: Mapped[Optional[datetime.datetime]] = mapped_column(TIMESTAMP, server_default=text('current_timestamp()'))
users: Mapped['Users'] = relationship('Users', back_populates='tokens')

2
run.sh
View file

@ -1,3 +1,3 @@
#!/bin/sh #!/bin/sh
alembic upgrade head alembic upgrade head
uvicorn cef_3M:app --port 8001 --host 0.0.0.0 uvicorn cef_3M:app --port 8001 --host 0.0.0.0 --reload