1
0
Fork 0
forked from External/ergo

track channel registrations per account

* limit the total number of registrations per account
* when an account is unregistered, unregister all its channels
This commit is contained in:
Shivaram Lingamneni 2019-02-06 04:32:04 -05:00
parent 8eefe869d0
commit ff7bbc4a9c
7 changed files with 120 additions and 3 deletions

View file

@ -33,6 +33,7 @@ const (
keyAccountEnforcement = "account.customenforcement %s"
keyAccountVHost = "account.vhost %s"
keyCertToAccount = "account.creds.certfp %s"
keyAccountChannels = "account.channels %s"
keyVHostQueueAcctToId = "vhostQueue %s"
vhostRequestIdx = "vhostQueue"
@ -836,9 +837,15 @@ func (am *AccountManager) Unregister(account string) error {
nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount)
vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount)
vhostQueueKey := fmt.Sprintf(keyVHostQueueAcctToId, casefoldedAccount)
channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount)
var clients []*Client
var registeredChannels []string
defer func() {
am.server.channelRegistry.deleteByAccount(casefoldedAccount, registeredChannels)
}()
var credText string
var rawNicks string
@ -846,6 +853,7 @@ func (am *AccountManager) Unregister(account string) error {
defer am.serialCacheUpdateMutex.Unlock()
var accountName string
var channelsStr string
am.server.store.Update(func(tx *buntdb.Tx) error {
tx.Delete(accountKey)
accountName, _ = tx.Get(accountNameKey)
@ -859,6 +867,9 @@ func (am *AccountManager) Unregister(account string) error {
credText, err = tx.Get(credentialsKey)
tx.Delete(credentialsKey)
tx.Delete(vhostKey)
channelsStr, _ = tx.Get(channelsKey)
tx.Delete(channelsKey)
_, err := tx.Delete(vhostQueueKey)
am.decrementVHostQueueCount(casefoldedAccount, err)
return nil
@ -879,6 +890,7 @@ func (am *AccountManager) Unregister(account string) error {
skeleton, _ := Skeleton(accountName)
additionalNicks := unmarshalReservedNicks(rawNicks)
registeredChannels = unmarshalRegisteredChannels(channelsStr)
am.Lock()
defer am.Unlock()
@ -899,9 +911,32 @@ func (am *AccountManager) Unregister(account string) error {
if err != nil {
return errAccountDoesNotExist
}
return nil
}
func unmarshalRegisteredChannels(channelsStr string) (result []string) {
if channelsStr != "" {
result = strings.Split(channelsStr, ",")
}
return
}
func (am *AccountManager) ChannelsForAccount(account string) (channels []string) {
cfaccount, err := CasefoldName(account)
if err != nil {
return
}
var channelStr string
key := fmt.Sprintf(keyAccountChannels, cfaccount)
am.server.store.View(func(tx *buntdb.Tx) error {
channelStr, _ = tx.Get(key)
return nil
})
return unmarshalRegisteredChannels(channelStr)
}
func (am *AccountManager) AuthenticateByCertFP(client *Client) error {
if client.certfp == "" {
return errAccountInvalidCredentials