mirror of
https://github.com/ergochat/ergo.git
synced 2025-12-20 02:00:11 -08:00
initial vhosts implementation, #183
This commit is contained in:
parent
40d6cd02da
commit
5e62cc4ebc
15 changed files with 1013 additions and 365 deletions
323
irc/accounts.go
323
irc/accounts.go
|
|
@ -14,6 +14,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oragono/oragono/irc/caps"
|
||||
|
|
@ -30,14 +31,24 @@ const (
|
|||
keyAccountRegTime = "account.registered.time %s"
|
||||
keyAccountCredentials = "account.credentials %s"
|
||||
keyAccountAdditionalNicks = "account.additionalnicks %s"
|
||||
keyAccountVHost = "account.vhost %s"
|
||||
keyCertToAccount = "account.creds.certfp %s"
|
||||
|
||||
keyVHostQueueAcctToId = "vhostQueue %s"
|
||||
vhostRequestIdx = "vhostQueue"
|
||||
)
|
||||
|
||||
// everything about accounts is persistent; therefore, the database is the authoritative
|
||||
// source of truth for all account information. anything on the heap is just a cache
|
||||
type AccountManager struct {
|
||||
// XXX these are up here so they can be aligned to a 64-bit boundary, please forgive me
|
||||
// autoincrementing ID for vhost requests:
|
||||
vhostRequestID uint64
|
||||
vhostRequestPendingCount uint64
|
||||
|
||||
sync.RWMutex // tier 2
|
||||
serialCacheUpdateMutex sync.Mutex // tier 3
|
||||
vHostUpdateMutex sync.Mutex // tier 3
|
||||
|
||||
server *Server
|
||||
// track clients logged in to accounts
|
||||
|
|
@ -53,6 +64,7 @@ func NewAccountManager(server *Server) *AccountManager {
|
|||
}
|
||||
|
||||
am.buildNickToAccountIndex()
|
||||
am.initVHostRequestQueue()
|
||||
return &am
|
||||
}
|
||||
|
||||
|
|
@ -94,8 +106,44 @@ func (am *AccountManager) buildNickToAccountIndex() {
|
|||
am.nickToAccount = result
|
||||
am.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
func (am *AccountManager) initVHostRequestQueue() {
|
||||
if !am.server.AccountConfig().HostServ.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
am.vHostUpdateMutex.Lock()
|
||||
defer am.vHostUpdateMutex.Unlock()
|
||||
|
||||
// the db maps the account name to the autoincrementing integer ID of its request
|
||||
// create an numerically ordered index on ID, so we can list the oldest requests
|
||||
// finally, collect the integer id of the newest request and the total request count
|
||||
var total uint64
|
||||
var lastIDStr string
|
||||
err := am.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
err := tx.CreateIndex(vhostRequestIdx, fmt.Sprintf(keyVHostQueueAcctToId, "*"), buntdb.IndexInt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Descend(vhostRequestIdx, func(key, value string) bool {
|
||||
if lastIDStr == "" {
|
||||
lastIDStr = value
|
||||
}
|
||||
total++
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
am.server.logger.Error("internal", "could not create vhost queue index", err.Error())
|
||||
}
|
||||
|
||||
lastID, _ := strconv.ParseUint(lastIDStr, 10, 64)
|
||||
am.server.logger.Debug("services", fmt.Sprintf("vhost queue length is %d, autoincrementing id is %d", total, lastID))
|
||||
|
||||
atomic.StoreUint64(&am.vhostRequestID, lastID)
|
||||
atomic.StoreUint64(&am.vhostRequestPendingCount, total)
|
||||
}
|
||||
|
||||
func (am *AccountManager) NickToAccount(nick string) string {
|
||||
|
|
@ -109,6 +157,17 @@ func (am *AccountManager) NickToAccount(nick string) string {
|
|||
return am.nickToAccount[cfnick]
|
||||
}
|
||||
|
||||
func (am *AccountManager) AccountToClients(account string) (result []*Client) {
|
||||
cfaccount, err := CasefoldName(account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
am.RLock()
|
||||
defer am.RUnlock()
|
||||
return am.accountToClients[cfaccount]
|
||||
}
|
||||
|
||||
func (am *AccountManager) Register(client *Client, account string, callbackNamespace string, callbackValue string, passphrase string, certfp string) error {
|
||||
casefoldedAccount, err := CasefoldName(account)
|
||||
if err != nil || account == "" || account == "*" {
|
||||
|
|
@ -342,7 +401,12 @@ func (am *AccountManager) Verify(client *Client, account string, code string) er
|
|||
return err
|
||||
}
|
||||
|
||||
am.Login(client, raw.Name)
|
||||
raw.Verified = true
|
||||
clientAccount, err := am.deserializeRawAccount(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.Login(client, clientAccount)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -464,7 +528,7 @@ func (am *AccountManager) AuthenticateByPassphrase(client *Client, accountName s
|
|||
return errAccountInvalidCredentials
|
||||
}
|
||||
|
||||
am.Login(client, account.Name)
|
||||
am.Login(client, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -484,6 +548,11 @@ func (am *AccountManager) LoadAccount(accountName string) (result ClientAccount,
|
|||
return
|
||||
}
|
||||
|
||||
result, err = am.deserializeRawAccount(raw)
|
||||
return
|
||||
}
|
||||
|
||||
func (am *AccountManager) deserializeRawAccount(raw rawClientAccount) (result ClientAccount, err error) {
|
||||
result.Name = raw.Name
|
||||
regTimeInt, _ := strconv.ParseInt(raw.RegisteredAt, 10, 64)
|
||||
result.RegisteredAt = time.Unix(regTimeInt, 0)
|
||||
|
|
@ -495,6 +564,13 @@ func (am *AccountManager) LoadAccount(accountName string) (result ClientAccount,
|
|||
}
|
||||
result.AdditionalNicks = unmarshalReservedNicks(raw.AdditionalNicks)
|
||||
result.Verified = raw.Verified
|
||||
if raw.VHost != "" {
|
||||
e := json.Unmarshal([]byte(raw.VHost), &result.VHost)
|
||||
if e != nil {
|
||||
am.server.logger.Warning("internal", fmt.Sprintf("could not unmarshal vhost for account %s: %v", result.Name, e))
|
||||
// pretend they have no vhost and move on
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -506,6 +582,7 @@ func (am *AccountManager) loadRawAccount(tx *buntdb.Tx, casefoldedAccount string
|
|||
verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount)
|
||||
callbackKey := fmt.Sprintf(keyAccountCallback, casefoldedAccount)
|
||||
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
|
||||
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
|
||||
|
||||
_, e := tx.Get(accountKey)
|
||||
if e == buntdb.ErrNotFound {
|
||||
|
|
@ -518,6 +595,7 @@ func (am *AccountManager) loadRawAccount(tx *buntdb.Tx, casefoldedAccount string
|
|||
result.Credentials, _ = tx.Get(credentialsKey)
|
||||
result.Callback, _ = tx.Get(callbackKey)
|
||||
result.AdditionalNicks, _ = tx.Get(nicksKey)
|
||||
result.VHost, _ = tx.Get(vhostKey)
|
||||
|
||||
if _, e = tx.Get(verifiedKey); e == nil {
|
||||
result.Verified = true
|
||||
|
|
@ -540,6 +618,8 @@ func (am *AccountManager) Unregister(account string) error {
|
|||
verificationCodeKey := fmt.Sprintf(keyAccountVerificationCode, casefoldedAccount)
|
||||
verifiedKey := fmt.Sprintf(keyAccountVerified, casefoldedAccount)
|
||||
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
|
||||
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
|
||||
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
|
||||
|
||||
var clients []*Client
|
||||
|
||||
|
|
@ -560,6 +640,12 @@ func (am *AccountManager) Unregister(account string) error {
|
|||
tx.Delete(nicksKey)
|
||||
credText, err = tx.Get(credentialsKey)
|
||||
tx.Delete(credentialsKey)
|
||||
tx.Delete(vhostKey)
|
||||
_, err := tx.Delete(vhostQueueKey)
|
||||
if err != nil {
|
||||
// 2's complement decrement
|
||||
atomic.AddUint64(&am.vhostRequestPendingCount, ^uint64(0))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
|
|
@ -624,17 +710,226 @@ func (am *AccountManager) AuthenticateByCertFP(client *Client) error {
|
|||
}
|
||||
|
||||
// ok, we found an account corresponding to their certificate
|
||||
|
||||
am.Login(client, rawAccount.Name)
|
||||
clientAccount, err := am.deserializeRawAccount(rawAccount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.Login(client, clientAccount)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AccountManager) Login(client *Client, account string) {
|
||||
// represents someone's status in hostserv
|
||||
type VHostInfo struct {
|
||||
ApprovedVHost string
|
||||
Enabled bool
|
||||
RequestedVHost string
|
||||
RejectedVHost string
|
||||
RejectionReason string
|
||||
LastRequestTime time.Time
|
||||
}
|
||||
|
||||
// pair type, <VHostInfo, accountName>
|
||||
type PendingVHostRequest struct {
|
||||
VHostInfo
|
||||
Account string
|
||||
}
|
||||
|
||||
// callback type implementing the actual business logic of vhost operations
|
||||
type vhostMunger func(input VHostInfo) (output VHostInfo, err error)
|
||||
|
||||
func (am *AccountManager) VHostSet(account string, vhost string) (err error) {
|
||||
munger := func(input VHostInfo) (output VHostInfo, err error) {
|
||||
output = input
|
||||
output.Enabled = true
|
||||
output.ApprovedVHost = vhost
|
||||
return
|
||||
}
|
||||
|
||||
return am.performVHostChange(account, munger)
|
||||
}
|
||||
|
||||
func (am *AccountManager) VHostRequest(account string, vhost string) (err error) {
|
||||
munger := func(input VHostInfo) (output VHostInfo, err error) {
|
||||
output = input
|
||||
output.RequestedVHost = vhost
|
||||
output.RejectedVHost = ""
|
||||
output.RejectionReason = ""
|
||||
output.LastRequestTime = time.Now().UTC()
|
||||
return
|
||||
}
|
||||
|
||||
return am.performVHostChange(account, munger)
|
||||
}
|
||||
|
||||
func (am *AccountManager) VHostApprove(account string) (err error) {
|
||||
munger := func(input VHostInfo) (output VHostInfo, err error) {
|
||||
output = input
|
||||
output.Enabled = true
|
||||
output.ApprovedVHost = input.RequestedVHost
|
||||
output.RequestedVHost = ""
|
||||
output.RejectionReason = ""
|
||||
return
|
||||
}
|
||||
|
||||
return am.performVHostChange(account, munger)
|
||||
}
|
||||
|
||||
func (am *AccountManager) VHostReject(account string, reason string) (err error) {
|
||||
munger := func(input VHostInfo) (output VHostInfo, err error) {
|
||||
output = input
|
||||
output.RejectedVHost = output.RequestedVHost
|
||||
output.RequestedVHost = ""
|
||||
output.RejectionReason = reason
|
||||
return
|
||||
}
|
||||
|
||||
return am.performVHostChange(account, munger)
|
||||
}
|
||||
|
||||
func (am *AccountManager) VHostSetEnabled(client *Client, enabled bool) (err error) {
|
||||
munger := func(input VHostInfo) (output VHostInfo, err error) {
|
||||
output = input
|
||||
output.Enabled = enabled
|
||||
return
|
||||
}
|
||||
|
||||
return am.performVHostChange(client.Account(), munger)
|
||||
}
|
||||
|
||||
func (am *AccountManager) performVHostChange(account string, munger vhostMunger) (err error) {
|
||||
account, err = CasefoldName(account)
|
||||
if err != nil || account == "" {
|
||||
return errAccountDoesNotExist
|
||||
}
|
||||
|
||||
am.vHostUpdateMutex.Lock()
|
||||
defer am.vHostUpdateMutex.Unlock()
|
||||
|
||||
clientAccount, err := am.LoadAccount(account)
|
||||
if err != nil {
|
||||
return errAccountDoesNotExist
|
||||
} else if !clientAccount.Verified {
|
||||
return errAccountUnverified
|
||||
}
|
||||
|
||||
result, err := munger(clientAccount.VHost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vhtext, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return errAccountUpdateFailed
|
||||
}
|
||||
vhstr := string(vhtext)
|
||||
|
||||
key := fmt.Sprintf(keyAccountVHost, account)
|
||||
queueKey := fmt.Sprintf(keyVHostQueueAcctToId, account)
|
||||
err = am.server.store.Update(func(tx *buntdb.Tx) error {
|
||||
if _, _, err := tx.Set(key, vhstr, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update request queue
|
||||
if clientAccount.VHost.RequestedVHost == "" && result.RequestedVHost != "" {
|
||||
id := atomic.AddUint64(&am.vhostRequestID, 1)
|
||||
if _, _, err = tx.Set(queueKey, strconv.FormatUint(id, 10), nil); err != nil {
|
||||
return err
|
||||
}
|
||||
atomic.AddUint64(&am.vhostRequestPendingCount, 1)
|
||||
} else if clientAccount.VHost.RequestedVHost != "" && result.RequestedVHost == "" {
|
||||
_, err = tx.Delete(queueKey)
|
||||
if err != nil {
|
||||
// XXX this is the decrement operation for two's complement
|
||||
atomic.AddUint64(&am.vhostRequestPendingCount, ^uint64(0))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errAccountUpdateFailed
|
||||
}
|
||||
|
||||
am.applyVhostToClients(account, result)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AccountManager) VHostListRequests(limit int) (requests []PendingVHostRequest, total int) {
|
||||
am.vHostUpdateMutex.Lock()
|
||||
defer am.vHostUpdateMutex.Unlock()
|
||||
|
||||
total = int(atomic.LoadUint64(&am.vhostRequestPendingCount))
|
||||
|
||||
prefix := fmt.Sprintf(keyVHostQueueAcctToId, "")
|
||||
accounts := make([]string, 0, limit)
|
||||
err := am.server.store.View(func(tx *buntdb.Tx) error {
|
||||
return tx.Ascend(vhostRequestIdx, func(key, value string) bool {
|
||||
accounts = append(accounts, strings.TrimPrefix(key, prefix))
|
||||
return len(accounts) < limit
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
am.server.logger.Error("internal", "couldn't traverse vhost queue", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for _, account := range accounts {
|
||||
accountInfo, err := am.LoadAccount(account)
|
||||
if err == nil {
|
||||
requests = append(requests, PendingVHostRequest{
|
||||
Account: account,
|
||||
VHostInfo: accountInfo.VHost,
|
||||
})
|
||||
} else {
|
||||
am.server.logger.Error("internal", "corrupt account", account, err.Error())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (am *AccountManager) applyVHostInfo(client *Client, info VHostInfo) {
|
||||
// if hostserv is disabled in config, then don't grant vhosts
|
||||
// that were previously approved while it was enabled
|
||||
if !am.server.AccountConfig().HostServ.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
vhost := ""
|
||||
if info.Enabled {
|
||||
vhost = info.ApprovedVHost
|
||||
}
|
||||
oldNickmask := client.NickMaskString()
|
||||
updated := client.SetVHost(vhost)
|
||||
if updated {
|
||||
// TODO: doing I/O here is kind of a kludge
|
||||
go client.sendChghost(oldNickmask, vhost)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AccountManager) applyVhostToClients(account string, result VHostInfo) {
|
||||
am.RLock()
|
||||
clients := am.accountToClients[account]
|
||||
am.RUnlock()
|
||||
|
||||
for _, client := range clients {
|
||||
am.applyVHostInfo(client, result)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AccountManager) Login(client *Client, account ClientAccount) {
|
||||
changed := client.SetAccountName(account.Name)
|
||||
if changed {
|
||||
go client.nickTimer.Touch()
|
||||
}
|
||||
|
||||
am.applyVHostInfo(client, account.VHost)
|
||||
|
||||
casefoldedAccount := client.Account()
|
||||
am.Lock()
|
||||
defer am.Unlock()
|
||||
|
||||
am.loginToAccount(client, account)
|
||||
casefoldedAccount := client.Account()
|
||||
am.accountToClients[casefoldedAccount] = append(am.accountToClients[casefoldedAccount], client)
|
||||
}
|
||||
|
||||
|
|
@ -691,6 +986,7 @@ type ClientAccount struct {
|
|||
Credentials AccountCredentials
|
||||
Verified bool
|
||||
AdditionalNicks []string
|
||||
VHost VHostInfo
|
||||
}
|
||||
|
||||
// convenience for passing around raw serialized account data
|
||||
|
|
@ -701,14 +997,7 @@ type rawClientAccount struct {
|
|||
Callback string
|
||||
Verified bool
|
||||
AdditionalNicks string
|
||||
}
|
||||
|
||||
// loginToAccount logs the client into the given account.
|
||||
func (am *AccountManager) loginToAccount(client *Client, account string) {
|
||||
changed := client.SetAccountName(account)
|
||||
if changed {
|
||||
go client.nickTimer.Touch()
|
||||
}
|
||||
VHost string
|
||||
}
|
||||
|
||||
// logoutOfAccount logs the client out of their current account.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue