From 7ce06362764ee35629521eacc1fdee5405370efd Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 4 Jan 2023 05:06:21 -0500 Subject: [PATCH] refactor of channel persistence to use UUIDs --- irc/accounts.go | 24 +-- irc/bunt/bunt_datastore.go | 106 +++++++++++ irc/channel.go | 100 +++++----- irc/channelmanager.go | 297 ++++++++++++++++------------- irc/channelreg.go | 381 ++----------------------------------- irc/chanserv.go | 35 ++-- irc/database.go | 159 +++++++++++++--- irc/datastore/datastore.go | 45 +++++ irc/errors.go | 1 + irc/getters.go | 6 + irc/handlers.go | 2 +- irc/hostserv.go | 2 +- irc/import.go | 61 +++--- irc/legacy.go | 121 ++++++++++++ irc/nickserv.go | 4 +- irc/serde.go | 37 ++++ irc/server.go | 20 +- irc/utils/uuid.go | 56 ++++++ 18 files changed, 804 insertions(+), 653 deletions(-) create mode 100644 irc/bunt/bunt_datastore.go create mode 100644 irc/datastore/datastore.go create mode 100644 irc/serde.go create mode 100644 irc/utils/uuid.go diff --git a/irc/accounts.go b/irc/accounts.go index 79646875..181da6df 100644 --- a/irc/accounts.go +++ b/irc/accounts.go @@ -39,7 +39,6 @@ const ( keyAccountSettings = "account.settings %s" keyAccountVHost = "account.vhost %s" keyCertToAccount = "account.creds.certfp %s" - keyAccountChannels = "account.channels %s" // channels registered to the account keyAccountLastSeen = "account.lastseen %s" keyAccountReadMarkers = "account.readmarkers %s" keyAccountModes = "account.modes %s" // user modes for the always-on client as a string @@ -1765,7 +1764,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error { nicksKey := fmt.Sprintf(keyAccountAdditionalNicks, casefoldedAccount) settingsKey := fmt.Sprintf(keyAccountSettings, casefoldedAccount) vhostKey := fmt.Sprintf(keyAccountVHost, casefoldedAccount) - channelsKey := fmt.Sprintf(keyAccountChannels, casefoldedAccount) joinedChannelsKey := fmt.Sprintf(keyAccountChannelToModes, casefoldedAccount) lastSeenKey := fmt.Sprintf(keyAccountLastSeen, casefoldedAccount) readMarkersKey := fmt.Sprintf(keyAccountReadMarkers, casefoldedAccount) @@ -1781,10 +1779,9 @@ func (am *AccountManager) Unregister(account string, erase bool) error { am.killClients(clients) }() - var registeredChannels []string // on our way out, unregister all the account's channels and delete them from the db defer func() { - for _, channelName := range registeredChannels { + for _, channelName := range am.server.channels.ChannelsForAccount(casefoldedAccount) { err := am.server.channels.SetUnregistered(channelName, casefoldedAccount) if err != nil { am.server.logger.Error("internal", "couldn't unregister channel", channelName, err.Error()) @@ -1799,7 +1796,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error { defer am.serialCacheUpdateMutex.Unlock() var accountName string - var channelsStr string keepProtections := false am.server.store.Update(func(tx *buntdb.Tx) error { // get the unfolded account name; for an active account, this is @@ -1827,8 +1823,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error { credText, err = tx.Get(credentialsKey) tx.Delete(credentialsKey) tx.Delete(vhostKey) - channelsStr, _ = tx.Get(channelsKey) - tx.Delete(channelsKey) tx.Delete(joinedChannelsKey) tx.Delete(lastSeenKey) tx.Delete(readMarkersKey) @@ -1858,7 +1852,6 @@ func (am *AccountManager) Unregister(account string, erase bool) error { skeleton, _ := Skeleton(accountName) additionalNicks := unmarshalReservedNicks(rawNicks) - registeredChannels = unmarshalRegisteredChannels(channelsStr) am.Lock() defer am.Unlock() @@ -1890,21 +1883,6 @@ func unmarshalRegisteredChannels(channelsStr string) (result []string) { return } -func (am *AccountManager) ChannelsForAccount(account string) (channels []string) { - cfaccount, err := CasefoldName(account) - if err != nil { - return - } - - var channelStr string - key := fmt.Sprintf(keyAccountChannels, cfaccount) - am.server.store.View(func(tx *buntdb.Tx) error { - channelStr, _ = tx.Get(key) - return nil - }) - return unmarshalRegisteredChannels(channelStr) -} - func (am *AccountManager) AuthenticateByCertificate(client *Client, certfp string, peerCerts []*x509.Certificate, authzid string) (err error) { if certfp == "" { return errAccountInvalidCredentials diff --git a/irc/bunt/bunt_datastore.go b/irc/bunt/bunt_datastore.go new file mode 100644 index 00000000..76c831e3 --- /dev/null +++ b/irc/bunt/bunt_datastore.go @@ -0,0 +1,106 @@ +// Copyright (c) 2022 Shivaram Lingamneni +// released under the MIT license + +package bunt + +import ( + "fmt" + "strings" + "time" + + "github.com/tidwall/buntdb" + + "github.com/ergochat/ergo/irc/datastore" + "github.com/ergochat/ergo/irc/logger" + "github.com/ergochat/ergo/irc/utils" +) + +// BuntKey yields a string key corresponding to a (table, UUID) pair. +// Ideally this would not be public, but some of the migration code +// needs it. +func BuntKey(table datastore.Table, uuid utils.UUID) string { + return fmt.Sprintf("%x %s", table, uuid.String()) +} + +// buntdbDatastore implements datastore.Datastore using a buntdb. +type buntdbDatastore struct { + db *buntdb.DB + logger *logger.Manager +} + +// NewBuntdbDatastore returns a datastore.Datastore backed by buntdb. +func NewBuntdbDatastore(db *buntdb.DB, logger *logger.Manager) datastore.Datastore { + return &buntdbDatastore{ + db: db, + logger: logger, + } +} + +func (b *buntdbDatastore) Backoff() time.Duration { + return 0 +} + +func (b *buntdbDatastore) GetAll(table datastore.Table) (result []datastore.KV, err error) { + tablePrefix := fmt.Sprintf("%x ", table) + err = b.db.View(func(tx *buntdb.Tx) error { + err := tx.AscendGreaterOrEqual("", tablePrefix, func(key, value string) bool { + if !strings.HasPrefix(key, tablePrefix) { + return false + } + uuid, err := utils.DecodeUUID(strings.TrimPrefix(key, tablePrefix)) + if err == nil { + result = append(result, datastore.KV{UUID: uuid, Value: []byte(value)}) + } else { + b.logger.Error("datastore", "invalid uuid", key) + } + return true + }) + return err + }) + return +} + +func (b *buntdbDatastore) Get(table datastore.Table, uuid utils.UUID) (value []byte, err error) { + buntKey := BuntKey(table, uuid) + var result string + err = b.db.View(func(tx *buntdb.Tx) error { + result, err = tx.Get(buntKey) + return err + }) + return []byte(result), err +} + +func (b *buntdbDatastore) Set(table datastore.Table, uuid utils.UUID, value []byte, expiration time.Time) (err error) { + buntKey := BuntKey(table, uuid) + var setOptions *buntdb.SetOptions + if !expiration.IsZero() { + ttl := time.Until(expiration) + if ttl > 0 { + setOptions = &buntdb.SetOptions{Expires: true, TTL: ttl} + } else { + return nil // it already expired, i guess? + } + } + strVal := string(value) + + err = b.db.Update(func(tx *buntdb.Tx) error { + _, _, err := tx.Set(buntKey, strVal, setOptions) + return err + }) + return +} + +func (b *buntdbDatastore) Delete(table datastore.Table, key utils.UUID) (err error) { + buntKey := BuntKey(table, key) + err = b.db.Update(func(tx *buntdb.Tx) error { + _, err := tx.Delete(buntKey) + return err + }) + // deleting a nonexistent key is not considered an error + switch err { + case buntdb.ErrNotFound: + return nil + default: + return err + } +} diff --git a/irc/channel.go b/irc/channel.go index 2cb4e385..ca8e45f1 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -16,6 +16,7 @@ import ( "github.com/ergochat/irc-go/ircutils" "github.com/ergochat/ergo/irc/caps" + "github.com/ergochat/ergo/irc/datastore" "github.com/ergochat/ergo/irc/history" "github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/utils" @@ -50,14 +51,14 @@ type Channel struct { stateMutex sync.RWMutex // tier 1 writebackLock sync.Mutex // tier 1.5 joinPartMutex sync.Mutex // tier 3 - ensureLoaded utils.Once // manages loading stored registration info from the database dirtyBits uint settings ChannelSettings + uuid utils.UUID } // NewChannel creates a new channel from a `Server` and a `name` // string, which must be unique on the server. -func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channel { +func NewChannel(s *Server, name, casefoldedName string, registered bool, regInfo RegisteredChannel) *Channel { config := s.Config() channel := &Channel{ @@ -71,14 +72,15 @@ func NewChannel(s *Server, name, casefoldedName string, registered bool) *Channe channel.initializeLists() channel.history.Initialize(0, 0) - if !registered { + if registered { + channel.applyRegInfo(regInfo) + } else { channel.resizeHistory(config) for _, mode := range config.Channels.defaultModes { channel.flags.SetMode(mode, true) } - // no loading to do, so "mark" the load operation as "done": - channel.ensureLoaded.Do(func() {}) - } // else: modes will be loaded before first join + channel.uuid = utils.GenerateUUIDv4() + } return channel } @@ -92,24 +94,6 @@ func (channel *Channel) initializeLists() { channel.accountToUMode = make(map[string]modes.Mode) } -// EnsureLoaded blocks until the channel's registration info has been loaded -// from the database. -func (channel *Channel) EnsureLoaded() { - channel.ensureLoaded.Do(func() { - nmc := channel.NameCasefolded() - info, err := channel.server.channelRegistry.LoadChannel(nmc) - if err == nil { - channel.applyRegInfo(info) - } else { - channel.server.logger.Error("internal", "couldn't load channel", nmc, err.Error()) - } - }) -} - -func (channel *Channel) IsLoaded() bool { - return channel.ensureLoaded.Done() -} - func (channel *Channel) resizeHistory(config *Config) { status, _, _ := channel.historyStatus(config) if status == HistoryEphemeral { @@ -126,6 +110,7 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) { channel.stateMutex.Lock() defer channel.stateMutex.Unlock() + channel.uuid = chanReg.UUID channel.registeredFounder = chanReg.Founder channel.registeredTime = chanReg.RegisteredAt channel.topic = chanReg.Topic @@ -150,38 +135,41 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) { } // obtain a consistent snapshot of the channel state that can be persisted to the DB -func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredChannel) { +func (channel *Channel) ExportRegistration() (info RegisteredChannel) { channel.stateMutex.RLock() defer channel.stateMutex.RUnlock() info.Name = channel.name - info.NameCasefolded = channel.nameCasefolded + info.UUID = channel.uuid info.Founder = channel.registeredFounder info.RegisteredAt = channel.registeredTime - if includeFlags&IncludeTopic != 0 { - info.Topic = channel.topic - info.TopicSetBy = channel.topicSetBy - info.TopicSetTime = channel.topicSetTime - } + info.Topic = channel.topic + info.TopicSetBy = channel.topicSetBy + info.TopicSetTime = channel.topicSetTime - if includeFlags&IncludeModes != 0 { - info.Key = channel.key - info.Forward = channel.forward - info.Modes = channel.flags.AllModes() - info.UserLimit = channel.userLimit - } + info.Key = channel.key + info.Forward = channel.forward + info.Modes = channel.flags.AllModes() + info.UserLimit = channel.userLimit - if includeFlags&IncludeLists != 0 { - info.Bans = channel.lists[modes.BanMask].Masks() - info.Invites = channel.lists[modes.InviteMask].Masks() - info.Excepts = channel.lists[modes.ExceptMask].Masks() - info.AccountToUMode = utils.CopyMap(channel.accountToUMode) - } + info.Bans = channel.lists[modes.BanMask].Masks() + info.Invites = channel.lists[modes.InviteMask].Masks() + info.Excepts = channel.lists[modes.ExceptMask].Masks() + info.AccountToUMode = utils.CopyMap(channel.accountToUMode) - if includeFlags&IncludeSettings != 0 { - info.Settings = channel.settings - } + info.Settings = channel.settings + + return +} + +func (channel *Channel) exportSummary() (info RegisteredChannel) { + channel.stateMutex.RLock() + defer channel.stateMutex.RUnlock() + + info.Name = channel.name + info.Founder = channel.registeredFounder + info.RegisteredAt = channel.registeredTime return } @@ -288,9 +276,19 @@ func (channel *Channel) performWrite(additionalDirtyBits uint) (err error) { return } - info := channel.ExportRegistration(dirtyBits) - err = channel.server.channelRegistry.StoreChannel(info, dirtyBits) - if err != nil { + var success bool + info := channel.ExportRegistration() + if b, err := info.Serialize(); err == nil { + if err := channel.server.dstore.Set(datastore.TableChannels, info.UUID, b, time.Time{}); err == nil { + success = true + } else { + channel.server.logger.Error("internal", "couldn't persist channel", info.Name, err.Error()) + } + } else { + channel.server.logger.Error("internal", "couldn't serialize channel", info.Name, err.Error()) + } + + if !success { channel.stateMutex.Lock() channel.dirtyBits = channel.dirtyBits | dirtyBits channel.stateMutex.Unlock() @@ -314,6 +312,7 @@ func (channel *Channel) SetRegistered(founder string) error { // SetUnregistered deletes the channel's registration information. func (channel *Channel) SetUnregistered(expectedFounder string) { + uuid := utils.GenerateUUIDv4() channel.stateMutex.Lock() defer channel.stateMutex.Unlock() @@ -324,6 +323,9 @@ func (channel *Channel) SetUnregistered(expectedFounder string) { var zeroTime time.Time channel.registeredTime = zeroTime channel.accountToUMode = make(map[string]modes.Mode) + // reset the UUID so that any re-registration will persist under + // a separate key: + channel.uuid = uuid } // implements `CHANSERV CLEAR #chan ACCESS` (resets bans, invites, excepts, and amodes) diff --git a/irc/channelmanager.go b/irc/channelmanager.go index a1921c4b..5934ab43 100644 --- a/irc/channelmanager.go +++ b/irc/channelmanager.go @@ -6,7 +6,9 @@ package irc import ( "sort" "sync" + "time" + "github.com/ergochat/ergo/irc/datastore" "github.com/ergochat/ergo/irc/utils" ) @@ -25,85 +27,75 @@ type channelManagerEntry struct { type ChannelManager struct { sync.RWMutex // tier 2 // chans is the main data structure, mapping casefolded name -> *Channel - chans map[string]*channelManagerEntry - chansSkeletons utils.HashSet[string] // skeletons of *unregistered* chans - registeredChannels utils.HashSet[string] // casefolds of registered chans - registeredSkeletons utils.HashSet[string] // skeletons of registered chans - purgedChannels utils.HashSet[string] // casefolds of purged chans - server *Server + chans map[string]*channelManagerEntry + chansSkeletons utils.HashSet[string] + purgedChannels map[string]ChannelPurgeRecord // casefolded name to purge record + server *Server } // NewChannelManager returns a new ChannelManager. -func (cm *ChannelManager) Initialize(server *Server) { +func (cm *ChannelManager) Initialize(server *Server, config *Config) (err error) { cm.chans = make(map[string]*channelManagerEntry) cm.chansSkeletons = make(utils.HashSet[string]) cm.server = server - - // purging should work even if registration is disabled - cm.purgedChannels = cm.server.channelRegistry.PurgedChannels() - cm.loadRegisteredChannels(server.Config()) + return cm.loadRegisteredChannels(config) } -func (cm *ChannelManager) loadRegisteredChannels(config *Config) { - if !config.Channels.Registration.Enabled { +func (cm *ChannelManager) loadRegisteredChannels(config *Config) (err error) { + allChannels, err := FetchAndDeserializeAll[RegisteredChannel](datastore.TableChannels, cm.server.dstore, cm.server.logger) + if err != nil { + return + } + allPurgeRecords, err := FetchAndDeserializeAll[ChannelPurgeRecord](datastore.TableChannelPurges, cm.server.dstore, cm.server.logger) + if err != nil { return } - - var newChannels []*Channel - var collisions []string - defer func() { - for _, ch := range newChannels { - ch.EnsureLoaded() - cm.server.logger.Debug("channels", "initialized registered channel", ch.Name()) - } - for _, collision := range collisions { - cm.server.logger.Warning("channels", "registered channel collides with existing channel", collision) - } - }() - - rawNames := cm.server.channelRegistry.AllChannels() cm.Lock() defer cm.Unlock() - cm.registeredChannels = make(utils.HashSet[string], len(rawNames)) - cm.registeredSkeletons = make(utils.HashSet[string], len(rawNames)) - for _, name := range rawNames { - cfname, err := CasefoldChannel(name) - if err == nil { - cm.registeredChannels.Add(cfname) + cm.purgedChannels = make(map[string]ChannelPurgeRecord, len(allPurgeRecords)) + for _, purge := range allPurgeRecords { + cm.purgedChannels[purge.NameCasefolded] = purge + } + + for _, regInfo := range allChannels { + cfname, err := CasefoldChannel(regInfo.Name) + if err != nil { + cm.server.logger.Error("channels", "couldn't casefold registered channel, skipping", regInfo.Name, err.Error()) + continue + } else { + cm.server.logger.Debug("channels", "initializing registered channel", regInfo.Name) } - skeleton, err := Skeleton(name) + skeleton, err := Skeleton(regInfo.Name) if err == nil { - cm.registeredSkeletons.Add(skeleton) + cm.chansSkeletons.Add(skeleton) } - if !cm.purgedChannels.Has(cfname) { - if _, ok := cm.chans[cfname]; !ok { - ch := NewChannel(cm.server, name, cfname, true) - cm.chans[cfname] = &channelManagerEntry{ - channel: ch, - pendingJoins: 0, - } - newChannels = append(newChannels, ch) - } else { - collisions = append(collisions, name) + if _, ok := cm.purgedChannels[cfname]; !ok { + ch := NewChannel(cm.server, regInfo.Name, cfname, true, regInfo) + cm.chans[cfname] = &channelManagerEntry{ + channel: ch, + pendingJoins: 0, + skeleton: skeleton, } } } + + return nil } // Get returns an existing channel with name equivalent to `name`, or nil func (cm *ChannelManager) Get(name string) (channel *Channel) { name, err := CasefoldChannel(name) - if err == nil { - cm.RLock() - defer cm.RUnlock() - entry := cm.chans[name] - // if the channel is still loading, pretend we don't have it - if entry != nil && entry.channel.IsLoaded() { - return entry.channel - } + if err != nil { + return nil + } + cm.RLock() + defer cm.RUnlock() + entry := cm.chans[name] + if entry != nil { + return entry.channel } return nil } @@ -122,33 +114,26 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin cm.Lock() defer cm.Unlock() - if cm.purgedChannels.Has(casefoldedName) { + // check purges first; a registered purged channel will still be present in `chans` + if _, ok := cm.purgedChannels[casefoldedName]; ok { return nil, errChannelPurged, false } entry := cm.chans[casefoldedName] if entry == nil { - registered := cm.registeredChannels.Has(casefoldedName) - // enforce OpOnlyCreation - if !registered && server.Config().Channels.OpOnlyCreation && + if server.Config().Channels.OpOnlyCreation && !(isSajoin || client.HasRoleCapabs("chanreg")) { return nil, errInsufficientPrivs, false } // enforce confusables - if !registered && (cm.chansSkeletons.Has(skeleton) || cm.registeredSkeletons.Has(skeleton)) { + if cm.chansSkeletons.Has(skeleton) { return nil, errConfusableIdentifier, false } entry = &channelManagerEntry{ - channel: NewChannel(server, name, casefoldedName, registered), + channel: NewChannel(server, name, casefoldedName, false, RegisteredChannel{}), pendingJoins: 0, } - if !registered { - // for an unregistered channel, we already have the correct unfolded name - // and therefore the final skeleton. for a registered channel, we don't have - // the unfolded name yet (it needs to be loaded from the db), but we already - // have the final skeleton in `registeredSkeletons` so we don't need to track it - cm.chansSkeletons.Add(skeleton) - entry.skeleton = skeleton - } + cm.chansSkeletons.Add(skeleton) + entry.skeleton = skeleton cm.chans[casefoldedName] = entry newChannel = true } @@ -160,7 +145,6 @@ func (cm *ChannelManager) Join(client *Client, name string, key string, isSajoin return err, "" } - channel.EnsureLoaded() err, forward = channel.Join(client, key, isSajoin || newChannel, rb) cm.maybeCleanup(channel, true) @@ -252,13 +236,6 @@ func (cm *ChannelManager) SetRegistered(channelName string, account string) (err if err != nil { return err } - // transfer the skeleton from chansSkeletons to registeredSkeletons - skeleton := entry.skeleton - delete(cm.chansSkeletons, skeleton) - entry.skeleton = "" - cm.chans[cfname] = entry - cm.registeredChannels.Add(cfname) - cm.registeredSkeletons.Add(skeleton) return nil } @@ -268,17 +245,13 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e return err } - info, err := cm.server.channelRegistry.LoadChannel(cfname) - if err != nil { - return err - } - if info.Founder != account { - return errChannelNotOwnedByAccount - } + var uuid utils.UUID defer func() { if err == nil { - err = cm.server.channelRegistry.Delete(info) + if delErr := cm.server.dstore.Delete(datastore.TableChannels, uuid); delErr != nil { + cm.server.logger.Error("datastore", "couldn't delete channel registration", cfname, delErr.Error()) + } } }() @@ -286,15 +259,11 @@ func (cm *ChannelManager) SetUnregistered(channelName string, account string) (e defer cm.Unlock() entry := cm.chans[cfname] if entry != nil { - entry.channel.SetUnregistered(account) - delete(cm.registeredChannels, cfname) - // transfer the skeleton from registeredSkeletons to chansSkeletons - if skel, err := Skeleton(entry.channel.Name()); err == nil { - delete(cm.registeredSkeletons, skel) - cm.chansSkeletons.Add(skel) - entry.skeleton = skel - cm.chans[cfname] = entry + if entry.channel.Founder() != account { + return errChannelNotOwnedByAccount } + uuid = entry.channel.UUID() + entry.channel.SetUnregistered(account) // changes the UUID // #1619: if the channel has 0 members and was only being retained // because it was registered, clean it up: cm.maybeCleanupInternal(cfname, entry, false) @@ -322,12 +291,11 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) { var info RegisteredChannel defer func() { if channel != nil && info.Founder != "" { - channel.Store(IncludeAllAttrs) - if oldCfname != newCfname { - // we just flushed the channel under its new name, therefore this delete - // cannot be overwritten by a write to the old name: - cm.server.channelRegistry.Delete(info) - } + channel.MarkDirty(IncludeAllAttrs) + } + // always-on clients need to update their saved channel memberships + for _, member := range channel.Members() { + member.markDirty(IncludeChannels) } }() @@ -335,11 +303,11 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) { defer cm.Unlock() entry := cm.chans[oldCfname] - if entry == nil || !entry.channel.IsLoaded() { + if entry == nil { return errNoSuchChannel } channel = entry.channel - info = channel.ExportRegistration(IncludeInitial) + info = channel.ExportRegistration() registered := info.Founder != "" oldSkeleton, err := Skeleton(info.Name) @@ -348,13 +316,13 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) { } if newCfname != oldCfname { - if cm.chans[newCfname] != nil || cm.registeredChannels.Has(newCfname) { + if cm.chans[newCfname] != nil { return errChannelNameInUse } } if oldSkeleton != newSkeleton { - if cm.chansSkeletons.Has(newSkeleton) || cm.registeredSkeletons.Has(newSkeleton) { + if cm.chansSkeletons.Has(newSkeleton) { return errConfusableIdentifier } } @@ -364,15 +332,8 @@ func (cm *ChannelManager) Rename(name string, newName string) (err error) { entry.skeleton = newSkeleton } cm.chans[newCfname] = entry - if registered { - delete(cm.registeredChannels, oldCfname) - cm.registeredChannels.Add(newCfname) - delete(cm.registeredSkeletons, oldSkeleton) - cm.registeredSkeletons.Add(newSkeleton) - } else { - delete(cm.chansSkeletons, oldSkeleton) - cm.chansSkeletons.Add(newSkeleton) - } + delete(cm.chansSkeletons, oldSkeleton) + cm.chansSkeletons.Add(newSkeleton) entry.channel.Rename(newName, newCfname) return nil } @@ -390,7 +351,18 @@ func (cm *ChannelManager) Channels() (result []*Channel) { defer cm.RUnlock() result = make([]*Channel, 0, len(cm.chans)) for _, entry := range cm.chans { - if entry.channel.IsLoaded() { + result = append(result, entry.channel) + } + return +} + +// ListableChannels returns a slice of all non-purged channels. +func (cm *ChannelManager) ListableChannels() (result []*Channel) { + cm.RLock() + defer cm.RUnlock() + result = make([]*Channel, 0, len(cm.chans)) + for cfname, entry := range cm.chans { + if _, ok := cm.purgedChannels[cfname]; !ok { result = append(result, entry.channel) } } @@ -403,29 +375,46 @@ func (cm *ChannelManager) Purge(chname string, record ChannelPurgeRecord) (err e if err != nil { return errInvalidChannelName } - skel, err := Skeleton(chname) - if err != nil { - return errInvalidChannelName - } - cm.Lock() - cm.purgedChannels.Add(chname) - entry := cm.chans[chname] - if entry != nil { - delete(cm.chans, chname) - if entry.channel.Founder() != "" { - delete(cm.registeredSkeletons, skel) - } else { - delete(cm.chansSkeletons, skel) + record.NameCasefolded = chname + record.UUID = utils.GenerateUUIDv4() + + channel, err := func() (channel *Channel, err error) { + cm.Lock() + defer cm.Unlock() + + if _, ok := cm.purgedChannels[chname]; ok { + return nil, errChannelPurgedAlready } - } - cm.Unlock() - cm.server.channelRegistry.PurgeChannel(chname, record) - if entry != nil { - entry.channel.Purge("") + entry := cm.chans[chname] + // atomically prevent anyone from rejoining + cm.purgedChannels[chname] = record + if entry != nil { + channel = entry.channel + } + return + }() + + if err != nil { + return err } - return nil + + if channel != nil { + // actually kick everyone off the channel + channel.Purge("") + } + + var purgeBytes []byte + if purgeBytes, err = record.Serialize(); err != nil { + cm.server.logger.Error("internal", "couldn't serialize purge record", channel.Name(), err.Error()) + } + // TODO we need a better story about error handling for later + if err = cm.server.dstore.Set(datastore.TableChannelPurges, record.UUID, purgeBytes, time.Time{}); err != nil { + cm.server.logger.Error("datastore", "couldn't store purge record", chname, err.Error()) + } + + return } // IsPurged queries whether a channel is purged. @@ -436,7 +425,7 @@ func (cm *ChannelManager) IsPurged(chname string) (result bool) { } cm.RLock() - result = cm.purgedChannels.Has(chname) + _, result = cm.purgedChannels[chname] cm.RUnlock() return } @@ -449,14 +438,16 @@ func (cm *ChannelManager) Unpurge(chname string) (err error) { } cm.Lock() - found := cm.purgedChannels.Has(chname) + record, found := cm.purgedChannels[chname] delete(cm.purgedChannels, chname) cm.Unlock() - cm.server.channelRegistry.UnpurgeChannel(chname) if !found { return errNoSuchChannel } + if err := cm.server.dstore.Delete(datastore.TableChannelPurges, record.UUID); err != nil { + cm.server.logger.Error("datastore", "couldn't delete purge record", chname, err.Error()) + } return nil } @@ -475,8 +466,46 @@ func (cm *ChannelManager) UnfoldName(cfname string) (result string) { cm.RLock() entry := cm.chans[cfname] cm.RUnlock() - if entry != nil && entry.channel.IsLoaded() { + if entry != nil { return entry.channel.Name() } return cfname } + +func (cm *ChannelManager) LoadPurgeRecord(cfchname string) (record ChannelPurgeRecord, err error) { + cm.RLock() + defer cm.RUnlock() + + if record, ok := cm.purgedChannels[cfchname]; ok { + return record, nil + } else { + return record, errNoSuchChannel + } +} + +func (cm *ChannelManager) ChannelsForAccount(account string) (channels []string) { + cm.RLock() + defer cm.RUnlock() + + for cfname, entry := range cm.chans { + if entry.channel.Founder() == account { + channels = append(channels, cfname) + } + } + + return +} + +// AllChannels returns the uncasefolded names of all registered channels. +func (cm *ChannelManager) AllRegisteredChannels() (result []string) { + cm.RLock() + defer cm.RUnlock() + + for cfname, entry := range cm.chans { + if entry.channel.Founder() != "" { + result = append(result, cfname) + } + } + + return +} diff --git a/irc/channelreg.go b/irc/channelreg.go index bb0a851b..1978b4ef 100644 --- a/irc/channelreg.go +++ b/irc/channelreg.go @@ -5,13 +5,8 @@ package irc import ( "encoding/json" - "fmt" - "strconv" - "strings" "time" - "github.com/tidwall/buntdb" - "github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/utils" ) @@ -19,48 +14,6 @@ import ( // this is exclusively the *persistence* layer for channel registration; // channel creation/tracking/destruction is in channelmanager.go -const ( - keyChannelExists = "channel.exists %s" - keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped - keyChannelRegTime = "channel.registered.time %s" - keyChannelFounder = "channel.founder %s" - keyChannelTopic = "channel.topic %s" - keyChannelTopicSetBy = "channel.topic.setby %s" - keyChannelTopicSetTime = "channel.topic.settime %s" - keyChannelBanlist = "channel.banlist %s" - keyChannelExceptlist = "channel.exceptlist %s" - keyChannelInvitelist = "channel.invitelist %s" - keyChannelPassword = "channel.key %s" - keyChannelModes = "channel.modes %s" - keyChannelAccountToUMode = "channel.accounttoumode %s" - keyChannelUserLimit = "channel.userlimit %s" - keyChannelSettings = "channel.settings %s" - keyChannelForward = "channel.forward %s" - - keyChannelPurged = "channel.purged %s" -) - -var ( - channelKeyStrings = []string{ - keyChannelExists, - keyChannelName, - keyChannelRegTime, - keyChannelFounder, - keyChannelTopic, - keyChannelTopicSetBy, - keyChannelTopicSetTime, - keyChannelBanlist, - keyChannelExceptlist, - keyChannelInvitelist, - keyChannelPassword, - keyChannelModes, - keyChannelAccountToUMode, - keyChannelUserLimit, - keyChannelSettings, - keyChannelForward, - } -) - // these are bit flags indicating what part of the channel status is "dirty" // and needs to be read from memory and written to the db const ( @@ -80,8 +33,8 @@ const ( type RegisteredChannel struct { // Name of the channel. Name string - // Casefolded name of the channel. - NameCasefolded string + // UUID for the datastore. + UUID utils.UUID // RegisteredAt represents the time that the channel was registered. RegisteredAt time.Time // Founder indicates the founder of the channel. @@ -112,322 +65,26 @@ type RegisteredChannel struct { Settings ChannelSettings } +func (r *RegisteredChannel) Serialize() ([]byte, error) { + return json.Marshal(r) +} + +func (r *RegisteredChannel) Deserialize(b []byte) (err error) { + return json.Unmarshal(b, r) +} + type ChannelPurgeRecord struct { - Oper string - PurgedAt time.Time - Reason string + NameCasefolded string `json:"Name"` + UUID utils.UUID + Oper string + PurgedAt time.Time + Reason string } -// ChannelRegistry manages registered channels. -type ChannelRegistry struct { - server *Server +func (c *ChannelPurgeRecord) Serialize() ([]byte, error) { + return json.Marshal(c) } -// NewChannelRegistry returns a new ChannelRegistry. -func (reg *ChannelRegistry) Initialize(server *Server) { - reg.server = server -} - -// AllChannels returns the uncasefolded names of all registered channels. -func (reg *ChannelRegistry) AllChannels() (result []string) { - prefix := fmt.Sprintf(keyChannelName, "") - reg.server.store.View(func(tx *buntdb.Tx) error { - return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool { - if !strings.HasPrefix(key, prefix) { - return false - } - result = append(result, value) - return true - }) - }) - - return -} - -// PurgedChannels returns the set of all casefolded channel names that have been purged -func (reg *ChannelRegistry) PurgedChannels() (result utils.HashSet[string]) { - result = make(utils.HashSet[string]) - - prefix := fmt.Sprintf(keyChannelPurged, "") - reg.server.store.View(func(tx *buntdb.Tx) error { - return tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool { - if !strings.HasPrefix(key, prefix) { - return false - } - channel := strings.TrimPrefix(key, prefix) - result.Add(channel) - return true - }) - }) - return -} - -// StoreChannel obtains a consistent view of a channel, then persists it to the store. -func (reg *ChannelRegistry) StoreChannel(info RegisteredChannel, includeFlags uint) (err error) { - if !reg.server.ChannelRegistrationEnabled() { - return - } - - if info.Founder == "" { - // sanity check, don't try to store an unregistered channel - return - } - - reg.server.store.Update(func(tx *buntdb.Tx) error { - reg.saveChannel(tx, info, includeFlags) - return nil - }) - - return nil -} - -// LoadChannel loads a channel from the store. -func (reg *ChannelRegistry) LoadChannel(nameCasefolded string) (info RegisteredChannel, err error) { - if !reg.server.ChannelRegistrationEnabled() { - err = errFeatureDisabled - return - } - - channelKey := nameCasefolded - // nice to have: do all JSON (de)serialization outside of the buntdb transaction - err = reg.server.store.View(func(tx *buntdb.Tx) error { - _, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey)) - if dberr == buntdb.ErrNotFound { - // chan does not already exist, return - return errNoSuchChannel - } - - // channel exists, load it - name, _ := tx.Get(fmt.Sprintf(keyChannelName, channelKey)) - regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, channelKey)) - regTimeInt, _ := strconv.ParseInt(regTime, 10, 64) - founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, channelKey)) - topic, _ := tx.Get(fmt.Sprintf(keyChannelTopic, channelKey)) - topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey)) - var topicSetTime time.Time - topicSetTimeStr, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey)) - if topicSetTimeInt, topicSetTimeErr := strconv.ParseInt(topicSetTimeStr, 10, 64); topicSetTimeErr == nil { - topicSetTime = time.Unix(0, topicSetTimeInt).UTC() - } - password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey)) - modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey)) - userLimitString, _ := tx.Get(fmt.Sprintf(keyChannelUserLimit, channelKey)) - forward, _ := tx.Get(fmt.Sprintf(keyChannelForward, channelKey)) - banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey)) - exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey)) - invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey)) - accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey)) - settingsString, _ := tx.Get(fmt.Sprintf(keyChannelSettings, channelKey)) - - modeSlice := make([]modes.Mode, len(modeString)) - for i, mode := range modeString { - modeSlice[i] = modes.Mode(mode) - } - - userLimit, _ := strconv.Atoi(userLimitString) - - var banlist map[string]MaskInfo - _ = json.Unmarshal([]byte(banlistString), &banlist) - var exceptlist map[string]MaskInfo - _ = json.Unmarshal([]byte(exceptlistString), &exceptlist) - var invitelist map[string]MaskInfo - _ = json.Unmarshal([]byte(invitelistString), &invitelist) - accountToUMode := make(map[string]modes.Mode) - _ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode) - - var settings ChannelSettings - _ = json.Unmarshal([]byte(settingsString), &settings) - - info = RegisteredChannel{ - Name: name, - NameCasefolded: nameCasefolded, - RegisteredAt: time.Unix(0, regTimeInt).UTC(), - Founder: founder, - Topic: topic, - TopicSetBy: topicSetBy, - TopicSetTime: topicSetTime, - Key: password, - Modes: modeSlice, - Bans: banlist, - Excepts: exceptlist, - Invites: invitelist, - AccountToUMode: accountToUMode, - UserLimit: int(userLimit), - Settings: settings, - Forward: forward, - } - return nil - }) - - return -} - -// Delete deletes a channel corresponding to `info`. If no such channel -// is present in the database, no error is returned. -func (reg *ChannelRegistry) Delete(info RegisteredChannel) (err error) { - if !reg.server.ChannelRegistrationEnabled() { - return - } - - reg.server.store.Update(func(tx *buntdb.Tx) error { - reg.deleteChannel(tx, info.NameCasefolded, info) - return nil - }) - return nil -} - -// delete a channel, unless it was overwritten by another registration of the same channel -func (reg *ChannelRegistry) deleteChannel(tx *buntdb.Tx, key string, info RegisteredChannel) { - _, err := tx.Get(fmt.Sprintf(keyChannelExists, key)) - if err == nil { - regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, key)) - regTimeInt, _ := strconv.ParseInt(regTime, 10, 64) - registeredAt := time.Unix(0, regTimeInt).UTC() - founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, key)) - - // to see if we're deleting the right channel, confirm the founder and the registration time - if founder == info.Founder && registeredAt.Equal(info.RegisteredAt) { - for _, keyFmt := range channelKeyStrings { - tx.Delete(fmt.Sprintf(keyFmt, key)) - } - - // remove this channel from the client's list of registered channels - channelsKey := fmt.Sprintf(keyAccountChannels, info.Founder) - channelsStr, err := tx.Get(channelsKey) - if err == buntdb.ErrNotFound { - return - } - registeredChannels := unmarshalRegisteredChannels(channelsStr) - var nowRegisteredChannels []string - for _, channel := range registeredChannels { - if channel != key { - nowRegisteredChannels = append(nowRegisteredChannels, channel) - } - } - tx.Set(channelsKey, strings.Join(nowRegisteredChannels, ","), nil) - } - } -} - -func (reg *ChannelRegistry) updateAccountToChannelMapping(tx *buntdb.Tx, channelInfo RegisteredChannel) { - channelKey := channelInfo.NameCasefolded - chanFounderKey := fmt.Sprintf(keyChannelFounder, channelKey) - founder, existsErr := tx.Get(chanFounderKey) - if existsErr == buntdb.ErrNotFound || founder != channelInfo.Founder { - // add to new founder's list - accountChannelsKey := fmt.Sprintf(keyAccountChannels, channelInfo.Founder) - alreadyChannels, _ := tx.Get(accountChannelsKey) - newChannels := channelKey // this is the casefolded channel name - if alreadyChannels != "" { - newChannels = fmt.Sprintf("%s,%s", alreadyChannels, newChannels) - } - tx.Set(accountChannelsKey, newChannels, nil) - } - if existsErr == nil && founder != channelInfo.Founder { - // remove from old founder's list - accountChannelsKey := fmt.Sprintf(keyAccountChannels, founder) - alreadyChannelsRaw, _ := tx.Get(accountChannelsKey) - var newChannels []string - if alreadyChannelsRaw != "" { - for _, chname := range strings.Split(alreadyChannelsRaw, ",") { - if chname != channelInfo.NameCasefolded { - newChannels = append(newChannels, chname) - } - } - } - tx.Set(accountChannelsKey, strings.Join(newChannels, ","), nil) - } -} - -// saveChannel saves a channel to the store. -func (reg *ChannelRegistry) saveChannel(tx *buntdb.Tx, channelInfo RegisteredChannel, includeFlags uint) { - channelKey := channelInfo.NameCasefolded - // maintain the mapping of account -> registered channels - reg.updateAccountToChannelMapping(tx, channelInfo) - - if includeFlags&IncludeInitial != 0 { - tx.Set(fmt.Sprintf(keyChannelExists, channelKey), "1", nil) - tx.Set(fmt.Sprintf(keyChannelName, channelKey), channelInfo.Name, nil) - tx.Set(fmt.Sprintf(keyChannelRegTime, channelKey), strconv.FormatInt(channelInfo.RegisteredAt.UnixNano(), 10), nil) - tx.Set(fmt.Sprintf(keyChannelFounder, channelKey), channelInfo.Founder, nil) - } - - if includeFlags&IncludeTopic != 0 { - tx.Set(fmt.Sprintf(keyChannelTopic, channelKey), channelInfo.Topic, nil) - var topicSetTimeStr string - if !channelInfo.TopicSetTime.IsZero() { - topicSetTimeStr = strconv.FormatInt(channelInfo.TopicSetTime.UnixNano(), 10) - } - tx.Set(fmt.Sprintf(keyChannelTopicSetTime, channelKey), topicSetTimeStr, nil) - tx.Set(fmt.Sprintf(keyChannelTopicSetBy, channelKey), channelInfo.TopicSetBy, nil) - } - - if includeFlags&IncludeModes != 0 { - tx.Set(fmt.Sprintf(keyChannelPassword, channelKey), channelInfo.Key, nil) - modeString := modes.Modes(channelInfo.Modes).String() - tx.Set(fmt.Sprintf(keyChannelModes, channelKey), modeString, nil) - tx.Set(fmt.Sprintf(keyChannelUserLimit, channelKey), strconv.Itoa(channelInfo.UserLimit), nil) - tx.Set(fmt.Sprintf(keyChannelForward, channelKey), channelInfo.Forward, nil) - } - - if includeFlags&IncludeLists != 0 { - banlistString, _ := json.Marshal(channelInfo.Bans) - tx.Set(fmt.Sprintf(keyChannelBanlist, channelKey), string(banlistString), nil) - exceptlistString, _ := json.Marshal(channelInfo.Excepts) - tx.Set(fmt.Sprintf(keyChannelExceptlist, channelKey), string(exceptlistString), nil) - invitelistString, _ := json.Marshal(channelInfo.Invites) - tx.Set(fmt.Sprintf(keyChannelInvitelist, channelKey), string(invitelistString), nil) - accountToUModeString, _ := json.Marshal(channelInfo.AccountToUMode) - tx.Set(fmt.Sprintf(keyChannelAccountToUMode, channelKey), string(accountToUModeString), nil) - } - - if includeFlags&IncludeSettings != 0 { - settingsString, _ := json.Marshal(channelInfo.Settings) - tx.Set(fmt.Sprintf(keyChannelSettings, channelKey), string(settingsString), nil) - } -} - -// PurgeChannel records a channel purge. -func (reg *ChannelRegistry) PurgeChannel(chname string, record ChannelPurgeRecord) (err error) { - serialized, err := json.Marshal(record) - if err != nil { - return err - } - serializedStr := string(serialized) - key := fmt.Sprintf(keyChannelPurged, chname) - - return reg.server.store.Update(func(tx *buntdb.Tx) error { - tx.Set(key, serializedStr, nil) - return nil - }) -} - -// LoadPurgeRecord retrieves information about whether and how a channel was purged. -func (reg *ChannelRegistry) LoadPurgeRecord(chname string) (record ChannelPurgeRecord, err error) { - var rawRecord string - key := fmt.Sprintf(keyChannelPurged, chname) - reg.server.store.View(func(tx *buntdb.Tx) error { - rawRecord, _ = tx.Get(key) - return nil - }) - if rawRecord == "" { - err = errNoSuchChannel - return - } - err = json.Unmarshal([]byte(rawRecord), &record) - if err != nil { - reg.server.logger.Error("internal", "corrupt purge record", chname, err.Error()) - err = errNoSuchChannel - return - } - return -} - -// UnpurgeChannel deletes the record of a channel purge. -func (reg *ChannelRegistry) UnpurgeChannel(chname string) (err error) { - key := fmt.Sprintf(keyChannelPurged, chname) - return reg.server.store.Update(func(tx *buntdb.Tx) error { - tx.Delete(key) - return nil - }) +func (c *ChannelPurgeRecord) Deserialize(b []byte) error { + return json.Unmarshal(b, c) } diff --git a/irc/chanserv.go b/irc/chanserv.go index c4b75974..6e8389ab 100644 --- a/irc/chanserv.go +++ b/irc/chanserv.go @@ -459,7 +459,7 @@ func csRegisterHandler(service *ircService, server *Server, client *Client, comm // check whether a client has already registered too many channels func checkChanLimit(service *ircService, client *Client, rb *ResponseBuffer) (ok bool) { account := client.Account() - channelsAlreadyRegistered := client.server.accounts.ChannelsForAccount(account) + channelsAlreadyRegistered := client.server.channels.ChannelsForAccount(account) ok = len(channelsAlreadyRegistered) < client.server.Config().Channels.Registration.MaxChannelsPerAccount || client.HasRoleCapabs("chanreg") if !ok { service.Notice(rb, client.t("You have already registered the maximum number of channels; try dropping some with /CS UNREGISTER")) @@ -496,8 +496,8 @@ func csUnregisterHandler(service *ircService, server *Server, client *Client, co return } - info := channel.ExportRegistration(0) - channelKey := info.NameCasefolded + info := channel.exportSummary() + channelKey := channel.NameCasefolded() if !csPrivsCheck(service, info, client, rb) { return } @@ -519,7 +519,7 @@ func csClearHandler(service *ircService, server *Server, client *Client, command service.Notice(rb, client.t("Channel does not exist")) return } - if !csPrivsCheck(service, channel.ExportRegistration(0), client, rb) { + if !csPrivsCheck(service, channel.exportSummary(), client, rb) { return } @@ -550,7 +550,7 @@ func csTransferHandler(service *ircService, server *Server, client *Client, comm service.Notice(rb, client.t("Channel does not exist")) return } - regInfo := channel.ExportRegistration(0) + regInfo := channel.exportSummary() chname = regInfo.Name account := client.Account() isFounder := account != "" && account == regInfo.Founder @@ -729,11 +729,6 @@ func csPurgeListHandler(service *ircService, client *Client, rb *ResponseBuffer) } func csListHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) { - if !client.HasRoleCapabs("chanreg") { - service.Notice(rb, client.t("Insufficient privileges")) - return - } - var searchRegex *regexp.Regexp if len(params) > 0 { var err error @@ -746,7 +741,7 @@ func csListHandler(service *ircService, server *Server, client *Client, command service.Notice(rb, ircfmt.Unescape(client.t("*** $bChanServ LIST$b ***"))) - channels := server.channelRegistry.AllChannels() + channels := server.channels.AllRegisteredChannels() for _, channel := range channels { if searchRegex == nil || searchRegex.MatchString(channel) { service.Notice(rb, fmt.Sprintf(" %s", channel)) @@ -771,7 +766,7 @@ func csInfoHandler(service *ircService, server *Server, client *Client, command // purge status if client.HasRoleCapabs("chanreg") { - purgeRecord, err := server.channelRegistry.LoadPurgeRecord(chname) + purgeRecord, err := server.channels.LoadPurgeRecord(chname) if err == nil { service.Notice(rb, fmt.Sprintf(client.t("Channel %s was purged by the server operators and cannot be used"), chname)) service.Notice(rb, fmt.Sprintf(client.t("Purged by operator: %s"), purgeRecord.Oper)) @@ -789,13 +784,7 @@ func csInfoHandler(service *ircService, server *Server, client *Client, command var chinfo RegisteredChannel channel := server.channels.Get(params[0]) if channel != nil { - chinfo = channel.ExportRegistration(0) - } else { - chinfo, err = server.channelRegistry.LoadChannel(chname) - if err != nil && !(err == errNoSuchChannel || err == errFeatureDisabled) { - service.Notice(rb, client.t("An error occurred")) - return - } + chinfo = channel.exportSummary() } // channel exists but is unregistered, or doesn't exist: @@ -835,12 +824,12 @@ func csGetHandler(service *ircService, server *Server, client *Client, command s service.Notice(rb, client.t("No such channel")) return } - info := channel.ExportRegistration(IncludeSettings) + info := channel.exportSummary() if !csPrivsCheck(service, info, client, rb) { return } - displayChannelSetting(service, setting, info.Settings, client, rb) + displayChannelSetting(service, setting, channel.Settings(), client, rb) } func csSetHandler(service *ircService, server *Server, client *Client, command string, params []string, rb *ResponseBuffer) { @@ -850,12 +839,12 @@ func csSetHandler(service *ircService, server *Server, client *Client, command s service.Notice(rb, client.t("No such channel")) return } - info := channel.ExportRegistration(IncludeSettings) - settings := info.Settings + info := channel.exportSummary() if !csPrivsCheck(service, info, client, rb) { return } + settings := channel.Settings() var err error switch strings.ToLower(setting) { case "history": diff --git a/irc/database.go b/irc/database.go index b066a12a..a20e269e 100644 --- a/irc/database.go +++ b/irc/database.go @@ -14,6 +14,8 @@ import ( "strings" "time" + "github.com/ergochat/ergo/irc/bunt" + "github.com/ergochat/ergo/irc/datastore" "github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/utils" @@ -21,12 +23,19 @@ import ( ) const ( - // 'version' of the database schema - keySchemaVersion = "db.version" - // latest schema of the db - latestDbSchema = 22 + // TODO migrate metadata keys as well - keyCloakSecret = "crypto.cloak_secret" + // 'version' of the database schema + // latest schema of the db + latestDbSchema = 23 +) + +var ( + schemaVersionUUID = utils.UUID{0, 255, 85, 13, 212, 10, 191, 121, 245, 152, 142, 89, 97, 141, 219, 87} // AP9VDdQKv3n1mI5ZYY3bVw + cloakSecretUUID = utils.UUID{170, 214, 184, 208, 116, 181, 67, 75, 161, 23, 233, 16, 113, 251, 94, 229} // qta40HS1Q0uhF-kQcfte5Q + + keySchemaVersion = bunt.BuntKey(datastore.TableMetadata, schemaVersionUUID) + keyCloakSecret = bunt.BuntKey(datastore.TableMetadata, cloakSecretUUID) ) type SchemaChanger func(*Config, *buntdb.Tx) error @@ -99,10 +108,7 @@ func openDatabaseInternal(config *Config, allowAutoupgrade bool) (db *buntdb.DB, // read the current version string var version int err = db.View(func(tx *buntdb.Tx) (err error) { - vStr, err := tx.Get(keySchemaVersion) - if err == nil { - version, err = strconv.Atoi(vStr) - } + version, err = retrieveSchemaVersion(tx) return err }) if err != nil { @@ -130,6 +136,17 @@ func openDatabaseInternal(config *Config, allowAutoupgrade bool) (db *buntdb.DB, } } +func retrieveSchemaVersion(tx *buntdb.Tx) (version int, err error) { + if val, err := tx.Get(keySchemaVersion); err == nil { + return strconv.Atoi(val) + } + // legacy key: + if val, err := tx.Get("db.version"); err == nil { + return strconv.Atoi(val) + } + return 0, buntdb.ErrNotFound +} + func performAutoUpgrade(currentVersion int, config *Config) (err error) { path := config.Datastore.Path log.Printf("attempting to auto-upgrade schema from version %d to %d\n", currentVersion, latestDbSchema) @@ -167,8 +184,12 @@ func UpgradeDB(config *Config) (err error) { var version int err = store.Update(func(tx *buntdb.Tx) error { for { - vStr, _ := tx.Get(keySchemaVersion) - version, _ = strconv.Atoi(vStr) + if version == 0 { + version, err = retrieveSchemaVersion(tx) + if err != nil { + return err + } + } if version == latestDbSchema { // success! break @@ -183,11 +204,12 @@ func UpgradeDB(config *Config) (err error) { if err != nil { return err } - _, _, err = tx.Set(keySchemaVersion, strconv.Itoa(change.TargetVersion), nil) + version = change.TargetVersion + _, _, err = tx.Set(keySchemaVersion, strconv.Itoa(version), nil) if err != nil { return err } - log.Printf("successfully updated schema to version %d\n", change.TargetVersion) + log.Printf("successfully updated schema to version %d\n", version) } return nil }) @@ -198,19 +220,17 @@ func UpgradeDB(config *Config) (err error) { return err } -func LoadCloakSecret(db *buntdb.DB) (result string) { - db.View(func(tx *buntdb.Tx) error { - result, _ = tx.Get(keyCloakSecret) - return nil - }) - return +func LoadCloakSecret(dstore datastore.Datastore) (result string, err error) { + val, err := dstore.Get(datastore.TableMetadata, cloakSecretUUID) + if err != nil { + return + } + return string(val), nil } -func StoreCloakSecret(db *buntdb.DB, secret string) { - db.Update(func(tx *buntdb.Tx) error { - tx.Set(keyCloakSecret, secret, nil) - return nil - }) +func StoreCloakSecret(dstore datastore.Datastore, secret string) { + // TODO error checking + dstore.Set(datastore.TableMetadata, cloakSecretUUID, []byte(secret), time.Time{}) } func schemaChangeV1toV2(config *Config, tx *buntdb.Tx) error { @@ -1112,6 +1132,92 @@ func schemaChangeV21To22(config *Config, tx *buntdb.Tx) error { return nil } +// first phase of document-oriented database refactor: channels +func schemaChangeV22ToV23(config *Config, tx *buntdb.Tx) error { + keyChannelExists := "channel.exists " + var channelNames []string + tx.AscendGreaterOrEqual("", keyChannelExists, func(key, value string) bool { + if !strings.HasPrefix(key, keyChannelExists) { + return false + } + channelNames = append(channelNames, strings.TrimPrefix(key, keyChannelExists)) + return true + }) + for _, channelName := range channelNames { + channel, err := loadLegacyChannel(tx, channelName) + if err != nil { + log.Printf("error loading legacy channel %s: %v", channelName, err) + continue + } + channel.UUID = utils.GenerateUUIDv4() + newKey := bunt.BuntKey(datastore.TableChannels, channel.UUID) + j, err := json.Marshal(channel) + if err != nil { + log.Printf("error marshaling channel %s: %v", channelName, err) + continue + } + tx.Set(newKey, string(j), nil) + deleteLegacyChannel(tx, channelName) + } + + // purges + keyChannelPurged := "channel.purged " + var purgeKeys []string + var channelPurges []ChannelPurgeRecord + tx.AscendGreaterOrEqual("", keyChannelPurged, func(key, value string) bool { + if !strings.HasPrefix(key, keyChannelPurged) { + return false + } + purgeKeys = append(purgeKeys, key) + cfname := strings.TrimPrefix(key, keyChannelPurged) + var record ChannelPurgeRecord + err := json.Unmarshal([]byte(value), &record) + if err != nil { + log.Printf("error unmarshaling channel purge for %s: %v", cfname, err) + return true + } + record.NameCasefolded = cfname + record.UUID = utils.GenerateUUIDv4() + channelPurges = append(channelPurges, record) + return true + }) + for _, record := range channelPurges { + newKey := bunt.BuntKey(datastore.TableChannelPurges, record.UUID) + j, err := json.Marshal(record) + if err != nil { + log.Printf("error marshaling channel purge %s: %v", record.NameCasefolded, err) + continue + } + tx.Set(newKey, string(j), nil) + } + for _, purgeKey := range purgeKeys { + tx.Delete(purgeKey) + } + + // clean up denormalized account-to-channels mapping + keyAccountChannels := "account.channels " + var accountToChannels []string + tx.AscendGreaterOrEqual("", keyAccountChannels, func(key, value string) bool { + if !strings.HasPrefix(key, keyAccountChannels) { + return false + } + accountToChannels = append(accountToChannels, key) + return true + }) + for _, key := range accountToChannels { + tx.Delete(key) + } + + // migrate cloak secret + val, _ := tx.Get("crypto.cloak_secret") + tx.Set(keyCloakSecret, val, nil) + + // bump the legacy version key to mark the database as downgrade-incompatible + tx.Set("db.version", "23", nil) + + return nil +} + func getSchemaChange(initialVersion int) (result SchemaChange, ok bool) { for _, change := range allChanges { if initialVersion == change.InitialVersion { @@ -1227,4 +1333,9 @@ var allChanges = []SchemaChange{ TargetVersion: 22, Changer: schemaChangeV21To22, }, + { + InitialVersion: 22, + TargetVersion: 23, + Changer: schemaChangeV22ToV23, + }, } diff --git a/irc/datastore/datastore.go b/irc/datastore/datastore.go new file mode 100644 index 00000000..c9d40a1b --- /dev/null +++ b/irc/datastore/datastore.go @@ -0,0 +1,45 @@ +// Copyright (c) 2022 Shivaram Lingamneni +// released under the MIT license + +package datastore + +import ( + "time" + + "github.com/ergochat/ergo/irc/utils" +) + +type Table uint16 + +// XXX these are persisted and must remain stable; +// do not reorder, when deleting use _ to ensure that the deleted value is skipped +const ( + TableMetadata Table = iota + TableChannels + TableChannelPurges +) + +type KV struct { + UUID utils.UUID + Value []byte +} + +// A Datastore provides the following abstraction: +// 1. Tables, each keyed on a UUID (the implementation is free to merge +// the table name and the UUID into a single key as long as the rest of +// the contract can be satisfied). Table names are [a-z0-9_]+ +// 2. The ability to efficiently enumerate all uuid-value pairs in a table +// 3. Gets, sets, and deletes for individual (table, uuid) keys +type Datastore interface { + Backoff() time.Duration + + GetAll(table Table) ([]KV, error) + + // This is rarely used because it would typically lead to TOCTOU races + Get(table Table, key utils.UUID) (value []byte, err error) + + Set(table Table, key utils.UUID, value []byte, expiration time.Time) error + + // Note that deleting a nonexistent key is not considered an error + Delete(table Table, key utils.UUID) error +} diff --git a/irc/errors.go b/irc/errors.go index 34f7fcdb..5c1e4c9f 100644 --- a/irc/errors.go +++ b/irc/errors.go @@ -51,6 +51,7 @@ var ( errNoExistingBan = errors.New("Ban does not exist") errNoSuchChannel = errors.New(`No such channel`) errChannelPurged = errors.New(`This channel was purged by the server operators and cannot be used`) + errChannelPurgedAlready = errors.New(`This channel was already purged and cannot be purged again`) errConfusableIdentifier = errors.New("This identifier is confusable with one already in use") errInsufficientPrivs = errors.New("Insufficient privileges") errInvalidUsername = errors.New("Invalid username") diff --git a/irc/getters.go b/irc/getters.go index 78a39429..2746cdb5 100644 --- a/irc/getters.go +++ b/irc/getters.go @@ -638,3 +638,9 @@ func (channel *Channel) getAmode(cfaccount string) (result modes.Mode) { defer channel.stateMutex.RUnlock() return channel.accountToUMode[cfaccount] } + +func (channel *Channel) UUID() utils.UUID { + channel.stateMutex.RLock() + defer channel.stateMutex.RUnlock() + return channel.uuid +} diff --git a/irc/handlers.go b/irc/handlers.go index 05b713d2..b823c241 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -1718,7 +1718,7 @@ func listHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons clientIsOp := client.HasRoleCapabs("sajoin") if len(channels) == 0 { - for _, channel := range server.channels.Channels() { + for _, channel := range server.channels.ListableChannels() { if !clientIsOp && channel.flags.HasMode(modes.Secret) && !channel.hasClient(client) { continue } diff --git a/irc/hostserv.go b/irc/hostserv.go index d6b5a743..fa3f527f 100644 --- a/irc/hostserv.go +++ b/irc/hostserv.go @@ -193,6 +193,6 @@ func hsSetCloakSecretHandler(service *ircService, server *Server, client *Client service.Notice(rb, fmt.Sprintf(client.t("To confirm, run this command: %s"), fmt.Sprintf("/HS SETCLOAKSECRET %s %s", secret, expectedCode))) return } - StoreCloakSecret(server.store, secret) + StoreCloakSecret(server.dstore, secret) service.Notice(rb, client.t("Rotated the cloak secret; you must rehash or restart the server for it to take effect")) } diff --git a/irc/import.go b/irc/import.go index 2ebf88dd..1fb2da32 100644 --- a/irc/import.go +++ b/irc/import.go @@ -9,9 +9,13 @@ import ( "log" "os" "strconv" + "time" "github.com/tidwall/buntdb" + "github.com/ergochat/ergo/irc/bunt" + "github.com/ergochat/ergo/irc/datastore" + "github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/utils" ) @@ -20,7 +24,7 @@ const ( // XXX instead of referencing, e.g., keyAccountExists, we should write in the string literal // (to ensure that no matter what code changes happen elsewhere, we're still producing a // db of the hardcoded version) - importDBSchemaVersion = 22 + importDBSchemaVersion = 23 ) type userImport struct { @@ -54,8 +58,8 @@ type databaseImport struct { Channels map[string]channelImport } -func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result []byte, err error) { - processed := make(map[string]int, len(raw)) +func convertAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result map[string]modes.Mode, err error) { + result = make(map[string]modes.Mode) for accountName, mode := range raw { if len(mode) != 1 { return nil, fmt.Errorf("invalid mode %s for account %s", mode, accountName) @@ -64,10 +68,9 @@ func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[strin if err != nil || !validCfUsernames.Has(cfname) { log.Printf("skipping invalid amode recipient %s\n", accountName) } else { - processed[cfname] = int(mode[0]) + result[cfname] = modes.Mode(mode[0]) } } - result, err = json.Marshal(processed) return } @@ -147,8 +150,9 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden cfUsernames.Add(cfUsername) } + // TODO fix this: for chname, chInfo := range dbImport.Channels { - cfchname, err := CasefoldChannel(chname) + _, err := CasefoldChannel(chname) if err != nil { log.Printf("invalid channel name %s: %v", chname, err) continue @@ -158,43 +162,42 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden log.Printf("invalid founder %s for channel %s: %v", chInfo.Founder, chname, err) continue } - tx.Set(fmt.Sprintf(keyChannelExists, cfchname), "1", nil) - tx.Set(fmt.Sprintf(keyChannelName, cfchname), chname, nil) - tx.Set(fmt.Sprintf(keyChannelRegTime, cfchname), strconv.FormatInt(chInfo.RegisteredAt, 10), nil) - tx.Set(fmt.Sprintf(keyChannelFounder, cfchname), cffounder, nil) - accountChannelsKey := fmt.Sprintf(keyAccountChannels, cffounder) - founderChannels, fcErr := tx.Get(accountChannelsKey) - if fcErr != nil || founderChannels == "" { - founderChannels = cfchname - } else { - founderChannels = fmt.Sprintf("%s,%s", founderChannels, cfchname) - } - tx.Set(accountChannelsKey, founderChannels, nil) + var regInfo RegisteredChannel + regInfo.Name = chname + regInfo.UUID = utils.GenerateUUIDv4() + regInfo.Founder = cffounder + regInfo.RegisteredAt = time.Unix(0, chInfo.RegisteredAt).UTC() if chInfo.Topic != "" { - tx.Set(fmt.Sprintf(keyChannelTopic, cfchname), chInfo.Topic, nil) - tx.Set(fmt.Sprintf(keyChannelTopicSetTime, cfchname), strconv.FormatInt(chInfo.TopicSetAt, 10), nil) - tx.Set(fmt.Sprintf(keyChannelTopicSetBy, cfchname), chInfo.TopicSetBy, nil) + regInfo.Topic = chInfo.Topic + regInfo.TopicSetBy = chInfo.TopicSetBy + regInfo.TopicSetTime = time.Unix(0, chInfo.TopicSetAt).UTC() } + if len(chInfo.Amode) != 0 { - m, err := serializeAmodes(chInfo.Amode, cfUsernames) + m, err := convertAmodes(chInfo.Amode, cfUsernames) if err == nil { - tx.Set(fmt.Sprintf(keyChannelAccountToUMode, cfchname), string(m), nil) + regInfo.AccountToUMode = m } else { - log.Printf("couldn't serialize amodes for %s: %v", chname, err) + log.Printf("couldn't process amodes for %s: %v", chname, err) } } - tx.Set(fmt.Sprintf(keyChannelModes, cfchname), chInfo.Modes, nil) - if chInfo.Key != "" { - tx.Set(fmt.Sprintf(keyChannelPassword, cfchname), chInfo.Key, nil) + for _, mode := range chInfo.Modes { + regInfo.Modes = append(regInfo.Modes, modes.Mode(mode)) } + regInfo.Key = chInfo.Key if chInfo.Limit > 0 { - tx.Set(fmt.Sprintf(keyChannelUserLimit, cfchname), strconv.Itoa(chInfo.Limit), nil) + regInfo.UserLimit = chInfo.Limit } if chInfo.Forward != "" { if _, err := CasefoldChannel(chInfo.Forward); err == nil { - tx.Set(fmt.Sprintf(keyChannelForward, cfchname), chInfo.Forward, nil) + regInfo.Forward = chInfo.Forward } } + if j, err := json.Marshal(regInfo); err == nil { + tx.Set(bunt.BuntKey(datastore.TableChannels, regInfo.UUID), string(j), nil) + } else { + log.Printf("couldn't serialize channel %s: %v", chname, err) + } } if warnSkeletons { diff --git a/irc/legacy.go b/irc/legacy.go index 0a55d3ca..8cd95153 100644 --- a/irc/legacy.go +++ b/irc/legacy.go @@ -4,7 +4,15 @@ package irc import ( "encoding/base64" + "encoding/json" "errors" + "fmt" + "strconv" + "time" + + "github.com/tidwall/buntdb" + + "github.com/ergochat/ergo/irc/modes" ) var ( @@ -25,3 +33,116 @@ func decodeLegacyPasswordHash(hash string) ([]byte, error) { return nil, errInvalidPasswordHash } } + +// legacy channel registration code + +const ( + keyChannelExists = "channel.exists %s" + keyChannelName = "channel.name %s" // stores the 'preferred name' of the channel, not casemapped + keyChannelRegTime = "channel.registered.time %s" + keyChannelFounder = "channel.founder %s" + keyChannelTopic = "channel.topic %s" + keyChannelTopicSetBy = "channel.topic.setby %s" + keyChannelTopicSetTime = "channel.topic.settime %s" + keyChannelBanlist = "channel.banlist %s" + keyChannelExceptlist = "channel.exceptlist %s" + keyChannelInvitelist = "channel.invitelist %s" + keyChannelPassword = "channel.key %s" + keyChannelModes = "channel.modes %s" + keyChannelAccountToUMode = "channel.accounttoumode %s" + keyChannelUserLimit = "channel.userlimit %s" + keyChannelSettings = "channel.settings %s" + keyChannelForward = "channel.forward %s" + + keyChannelPurged = "channel.purged %s" +) + +func deleteLegacyChannel(tx *buntdb.Tx, nameCasefolded string) { + tx.Delete(fmt.Sprintf(keyChannelExists, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelName, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelRegTime, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelFounder, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelTopic, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelTopicSetBy, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelTopicSetTime, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelBanlist, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelExceptlist, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelInvitelist, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelPassword, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelModes, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelAccountToUMode, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelUserLimit, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelSettings, nameCasefolded)) + tx.Delete(fmt.Sprintf(keyChannelForward, nameCasefolded)) +} + +func loadLegacyChannel(tx *buntdb.Tx, nameCasefolded string) (info RegisteredChannel, err error) { + channelKey := nameCasefolded + // nice to have: do all JSON (de)serialization outside of the buntdb transaction + _, dberr := tx.Get(fmt.Sprintf(keyChannelExists, channelKey)) + if dberr == buntdb.ErrNotFound { + // chan does not already exist, return + err = errNoSuchChannel + return + } + + // channel exists, load it + name, _ := tx.Get(fmt.Sprintf(keyChannelName, channelKey)) + regTime, _ := tx.Get(fmt.Sprintf(keyChannelRegTime, channelKey)) + regTimeInt, _ := strconv.ParseInt(regTime, 10, 64) + founder, _ := tx.Get(fmt.Sprintf(keyChannelFounder, channelKey)) + topic, _ := tx.Get(fmt.Sprintf(keyChannelTopic, channelKey)) + topicSetBy, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetBy, channelKey)) + var topicSetTime time.Time + topicSetTimeStr, _ := tx.Get(fmt.Sprintf(keyChannelTopicSetTime, channelKey)) + if topicSetTimeInt, topicSetTimeErr := strconv.ParseInt(topicSetTimeStr, 10, 64); topicSetTimeErr == nil { + topicSetTime = time.Unix(0, topicSetTimeInt).UTC() + } + password, _ := tx.Get(fmt.Sprintf(keyChannelPassword, channelKey)) + modeString, _ := tx.Get(fmt.Sprintf(keyChannelModes, channelKey)) + userLimitString, _ := tx.Get(fmt.Sprintf(keyChannelUserLimit, channelKey)) + forward, _ := tx.Get(fmt.Sprintf(keyChannelForward, channelKey)) + banlistString, _ := tx.Get(fmt.Sprintf(keyChannelBanlist, channelKey)) + exceptlistString, _ := tx.Get(fmt.Sprintf(keyChannelExceptlist, channelKey)) + invitelistString, _ := tx.Get(fmt.Sprintf(keyChannelInvitelist, channelKey)) + accountToUModeString, _ := tx.Get(fmt.Sprintf(keyChannelAccountToUMode, channelKey)) + settingsString, _ := tx.Get(fmt.Sprintf(keyChannelSettings, channelKey)) + + modeSlice := make([]modes.Mode, len(modeString)) + for i, mode := range modeString { + modeSlice[i] = modes.Mode(mode) + } + + userLimit, _ := strconv.Atoi(userLimitString) + + var banlist map[string]MaskInfo + _ = json.Unmarshal([]byte(banlistString), &banlist) + var exceptlist map[string]MaskInfo + _ = json.Unmarshal([]byte(exceptlistString), &exceptlist) + var invitelist map[string]MaskInfo + _ = json.Unmarshal([]byte(invitelistString), &invitelist) + accountToUMode := make(map[string]modes.Mode) + _ = json.Unmarshal([]byte(accountToUModeString), &accountToUMode) + + var settings ChannelSettings + _ = json.Unmarshal([]byte(settingsString), &settings) + + info = RegisteredChannel{ + Name: name, + RegisteredAt: time.Unix(0, regTimeInt).UTC(), + Founder: founder, + Topic: topic, + TopicSetBy: topicSetBy, + TopicSetTime: topicSetTime, + Key: password, + Modes: modeSlice, + Bans: banlist, + Excepts: exceptlist, + Invites: invitelist, + AccountToUMode: accountToUMode, + UserLimit: int(userLimit), + Settings: settings, + Forward: forward, + } + return info, nil +} diff --git a/irc/nickserv.go b/irc/nickserv.go index 3c6cc1d6..f4f8b3ec 100644 --- a/irc/nickserv.go +++ b/irc/nickserv.go @@ -954,9 +954,9 @@ func nsInfoHandler(service *ircService, server *Server, client *Client, command func listRegisteredChannels(service *ircService, accountName string, rb *ResponseBuffer) { client := rb.session.client - channels := client.server.accounts.ChannelsForAccount(accountName) + channels := client.server.channels.ChannelsForAccount(accountName) service.Notice(rb, fmt.Sprintf(client.t("Account %s has %d registered channel(s)."), accountName, len(channels))) - for _, channel := range rb.session.client.server.accounts.ChannelsForAccount(accountName) { + for _, channel := range channels { service.Notice(rb, fmt.Sprintf(client.t("Registered channel: %s"), channel)) } } diff --git a/irc/serde.go b/irc/serde.go new file mode 100644 index 00000000..c40ee148 --- /dev/null +++ b/irc/serde.go @@ -0,0 +1,37 @@ +// Copyright (c) 2022 Shivaram Lingamneni +// released under the MIT license + +package irc + +import ( + "strconv" + + "github.com/ergochat/ergo/irc/datastore" + "github.com/ergochat/ergo/irc/logger" +) + +type Serializable interface { + Serialize() ([]byte, error) + Deserialize([]byte) error +} + +func FetchAndDeserializeAll[T any, C interface { + *T + Serializable +}](table datastore.Table, dstore datastore.Datastore, log *logger.Manager) (result []T, err error) { + rawRecords, err := dstore.GetAll(table) + if err != nil { + return + } + result = make([]T, len(rawRecords)) + pos := 0 + for _, record := range rawRecords { + err := C(&result[pos]).Deserialize(record.Value) + if err != nil { + log.Error("internal", "deserialization error", strconv.Itoa(int(table)), record.UUID.String(), err.Error()) + continue + } + pos++ + } + return result[:pos], nil +} diff --git a/irc/server.go b/irc/server.go index 1a7e1abd..e1f3805c 100644 --- a/irc/server.go +++ b/irc/server.go @@ -22,9 +22,12 @@ import ( "github.com/ergochat/irc-go/ircfmt" "github.com/okzk/sdnotify" + "github.com/tidwall/buntdb" + "github.com/ergochat/ergo/irc/bunt" "github.com/ergochat/ergo/irc/caps" "github.com/ergochat/ergo/irc/connection_limits" + "github.com/ergochat/ergo/irc/datastore" "github.com/ergochat/ergo/irc/flatip" "github.com/ergochat/ergo/irc/flock" "github.com/ergochat/ergo/irc/history" @@ -33,7 +36,6 @@ import ( "github.com/ergochat/ergo/irc/mysql" "github.com/ergochat/ergo/irc/sno" "github.com/ergochat/ergo/irc/utils" - "github.com/tidwall/buntdb" ) const ( @@ -66,7 +68,6 @@ type Server struct { accepts AcceptManager accounts AccountManager channels ChannelManager - channelRegistry ChannelRegistry clients ClientManager config atomic.Pointer[Config] configFilename string @@ -87,6 +88,7 @@ type Server struct { tracebackSignal chan os.Signal snomasks SnoManager store *buntdb.DB + dstore datastore.Datastore historyDB mysql.MySQL torLimiter connection_limits.TorLimiter whoWas WhoWasList @@ -98,6 +100,10 @@ type Server struct { // NewServer returns a new Oragono server. func NewServer(config *Config, logger *logger.Manager) (*Server, error) { + // sanity check that kernel randomness is available; on modern Linux, + // this will block until it is, on other platforms it may panic: + utils.GenerateUUIDv4() + // initialize data structures server := &Server{ ctime: time.Now().UTC(), @@ -716,7 +722,11 @@ func (server *Server) applyConfig(config *Config) (err error) { // now that the datastore is initialized, we can load the cloak secret from it // XXX this modifies config after the initial load, which is naughty, // but there's no data race because we haven't done SetConfig yet - config.Server.Cloaks.SetSecret(LoadCloakSecret(server.store)) + cloakSecret, err := LoadCloakSecret(server.dstore) + if err != nil { + return fmt.Errorf("Could not load cloak secret: %w", err) + } + config.Server.Cloaks.SetSecret(cloakSecret) // activate the new config server.config.Store(config) @@ -837,6 +847,7 @@ func (server *Server) loadDatastore(config *Config) error { db, err := OpenDatabase(config) if err == nil { server.store = db + server.dstore = bunt.NewBuntdbDatastore(db, server.logger) return nil } else { return fmt.Errorf("Failed to open datastore: %s", err.Error()) @@ -849,8 +860,7 @@ func (server *Server) loadFromDatastore(config *Config) (err error) { server.loadDLines() server.loadKLines() - server.channelRegistry.Initialize(server) - server.channels.Initialize(server) + server.channels.Initialize(server, config) server.accounts.Initialize(server) if config.Datastore.MySQL.Enabled { diff --git a/irc/utils/uuid.go b/irc/utils/uuid.go new file mode 100644 index 00000000..254dff53 --- /dev/null +++ b/irc/utils/uuid.go @@ -0,0 +1,56 @@ +// Copyright (c) 2022 Shivaram Lingamneni +// released under the MIT license + +package utils + +import ( + "crypto/rand" + "encoding/base64" + "errors" +) + +var ( + ErrInvalidUUID = errors.New("Invalid uuid") +) + +// Technically a UUIDv4 has version bits set, but this doesn't matter in practice +type UUID [16]byte + +func (u UUID) MarshalJSON() (b []byte, err error) { + b = make([]byte, 24) + b[0] = '"' + base64.RawURLEncoding.Encode(b[1:], u[:]) + b[23] = '"' + return +} + +func (u *UUID) UnmarshalJSON(b []byte) (err error) { + if len(b) != 24 { + return ErrInvalidUUID + } + readLen, err := base64.RawURLEncoding.Decode(u[:], b[1:23]) + if readLen != 16 { + return ErrInvalidUUID + } + return nil +} + +func (u *UUID) String() string { + return base64.RawURLEncoding.EncodeToString(u[:]) +} + +func GenerateUUIDv4() (result UUID) { + _, err := rand.Read(result[:]) + if err != nil { + panic(err) + } + return +} + +func DecodeUUID(ustr string) (result UUID, err error) { + length, err := base64.RawURLEncoding.Decode(result[:], []byte(ustr)) + if err == nil && length != 16 { + err = ErrInvalidUUID + } + return +}