1
0
Fork 0
forked from External/ergo
This commit is contained in:
Shivaram Lingamneni 2020-05-12 12:05:40 -04:00
parent 21958768d8
commit 67f35e5c8a
16 changed files with 819 additions and 108 deletions

View file

@ -18,5 +18,6 @@ type Config struct {
Timeout time.Duration
// XXX these are copied from elsewhere in the config:
ExpireTime time.Duration
ExpireTime time.Duration
TrackAccountMessages bool
}

View file

@ -7,7 +7,10 @@ import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"runtime/debug"
"sync"
"sync/atomic"
@ -19,6 +22,10 @@ import (
"github.com/oragono/oragono/irc/utils"
)
var (
ErrDisallowed = errors.New("disallowed")
)
const (
// maximum length in bytes of any message target (nickname or channel name) in its
// canonicalized (i.e., casefolded) state:
@ -27,30 +34,46 @@ const (
// latest schema of the db
latestDbSchema = "2"
keySchemaVersion = "db.version"
cleanupRowLimit = 50
cleanupPauseTime = 10 * time.Minute
// minor version indicates rollback-safe upgrades, i.e.,
// you can downgrade oragono and everything will work
latestDbMinorVersion = "1"
keySchemaMinorVersion = "db.minorversion"
cleanupRowLimit = 50
cleanupPauseTime = 10 * time.Minute
)
type MySQL struct {
timeout int64
db *sql.DB
logger *logger.Manager
type e struct{}
insertHistory *sql.Stmt
insertSequence *sql.Stmt
insertConversation *sql.Stmt
type MySQL struct {
timeout int64
trackAccountMessages uint32
db *sql.DB
logger *logger.Manager
insertHistory *sql.Stmt
insertSequence *sql.Stmt
insertConversation *sql.Stmt
insertAccountMessage *sql.Stmt
stateMutex sync.Mutex
config Config
wakeForgetter chan e
}
func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
mysql.logger = logger
mysql.wakeForgetter = make(chan e, 1)
mysql.SetConfig(config)
}
func (mysql *MySQL) SetConfig(config Config) {
atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
var trackAccountMessages uint32
if config.TrackAccountMessages {
trackAccountMessages = 1
}
atomic.StoreUint32(&mysql.trackAccountMessages, trackAccountMessages)
mysql.stateMutex.Lock()
mysql.config = config
mysql.stateMutex.Unlock()
@ -85,6 +108,7 @@ func (m *MySQL) Open() (err error) {
}
go m.cleanupLoop()
go m.forgetLoop()
return nil
}
@ -109,14 +133,35 @@ func (mysql *MySQL) fixSchemas() (err error) {
if err != nil {
return
}
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
if err != nil {
return
}
return
} else if err == nil && schema != latestDbSchema {
// TODO figure out what to do about schema changes
return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
} else {
} else if err != nil {
return err
}
return nil
var minorVersion string
err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
if err == sql.ErrNoRows {
// XXX for now, the only minor version upgrade is the account tracking tables
err = mysql.createComplianceTables()
if err != nil {
return
}
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
if err != nil {
return
}
} else if err == nil && minorVersion != latestDbMinorVersion {
// TODO: if minorVersion < latestDbMinorVersion, upgrade,
// if latestDbMinorVersion < minorVersion, ignore because backwards compatible
}
return
}
func (mysql *MySQL) createTables() (err error) {
@ -155,6 +200,32 @@ func (mysql *MySQL) createTables() (err error) {
return err
}
err = mysql.createComplianceTables()
if err != nil {
return err
}
return nil
}
func (mysql *MySQL) createComplianceTables() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
account VARBINARY(%[1]d) NOT NULL,
KEY (account, history_id)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
if err != nil {
return err
}
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
account VARBINARY(%[1]d) NOT NULL
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
if err != nil {
return err
}
return nil
}
@ -191,7 +262,10 @@ func (mysql *MySQL) cleanupLoop() {
}
func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
ids, maxNanotime, err := mysql.selectCleanupIDs(age)
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
defer cancel()
ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
if len(ids) == 0 {
mysql.logger.Debug("mysql", "found no rows to clean up")
return
@ -199,6 +273,10 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
return len(ids), mysql.deleteHistoryIDs(ctx, ids)
}
func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
// can't use ? binding for a variable number of arguments, build the IN clause manually
var inBuf bytes.Buffer
inBuf.WriteByte('(')
@ -210,25 +288,30 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
}
inBuf.WriteRune(')')
_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
if err != nil {
return
}
_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
if err != nil {
return
}
_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
if mysql.isTrackingAccountMessages() {
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes()))
if err != nil {
return
}
}
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
if err != nil {
return
}
count = len(ids)
return
}
func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
rows, err := mysql.db.Query(`
func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
rows, err := mysql.db.QueryContext(ctx, `
SELECT history.id, sequence.nanotime
FROM history
LEFT JOIN sequence ON history.id = sequence.history_id
@ -266,6 +349,109 @@ func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanoti
return
}
// wait for forget queue items and process them one by one
func (mysql *MySQL) forgetLoop() {
defer func() {
if r := recover(); r != nil {
mysql.logger.Error("mysql",
fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
time.Sleep(cleanupPauseTime)
go mysql.forgetLoop()
}
}()
for {
for {
found, err := mysql.doForget()
mysql.logError("error processing forget", err)
if err != nil {
time.Sleep(cleanupPauseTime)
}
if !found {
break
}
}
<-mysql.wakeForgetter
}
}
// dequeue an item from the forget queue and process it
func (mysql *MySQL) doForget() (found bool, err error) {
id, account, err := func() (id int64, account string, err error) {
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
defer cancel()
row := mysql.db.QueryRowContext(ctx,
`SELECT forget.id, forget.account FROM forget LIMIT 1;`)
err = row.Scan(&id, &account)
if err == sql.ErrNoRows {
return 0, "", nil
}
return
}()
if err != nil || account == "" {
return false, err
}
found = true
var count int
for {
start := time.Now()
count, err = mysql.doForgetIteration(account)
elapsed := time.Since(start)
if err != nil {
return true, err
}
if count == 0 {
break
}
time.Sleep(elapsed)
}
mysql.logger.Debug("mysql", "forget complete for account", account)
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
defer cancel()
_, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
return
}
func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
defer cancel()
rows, err := mysql.db.QueryContext(ctx, `
SELECT account_messages.history_id
FROM account_messages
WHERE account_messages.account = ?
LIMIT ?;`, account, cleanupRowLimit)
if err != nil {
return
}
defer rows.Close()
var ids []uint64
for rows.Next() {
var id uint64
err = rows.Scan(&id)
if err != nil {
return
}
ids = append(ids, id)
}
if len(ids) == 0 {
return
}
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
err = mysql.deleteHistoryIDs(ctx, ids)
return len(ids), err
}
func (mysql *MySQL) prepareStatements() (err error) {
mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
(data, msgid) VALUES (?, ?);`)
@ -282,6 +468,11 @@ func (mysql *MySQL) prepareStatements() (err error) {
if err != nil {
return
}
mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
(history_id, account) VALUES (?, ?);`)
if err != nil {
return
}
return
}
@ -290,6 +481,10 @@ func (mysql *MySQL) getTimeout() time.Duration {
return time.Duration(atomic.LoadInt64(&mysql.timeout))
}
func (mysql *MySQL) isTrackingAccountMessages() bool {
return atomic.LoadUint32(&mysql.trackAccountMessages) != 0
}
func (mysql *MySQL) logError(context string, err error) (quit bool) {
if err != nil {
mysql.logger.Error("mysql", context, err.Error())
@ -298,7 +493,27 @@ func (mysql *MySQL) logError(context string, err error) (quit bool) {
return false
}
func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
func (mysql *MySQL) Forget(account string) {
if mysql.db == nil || account == "" {
return
}
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
_, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
if mysql.logError("can't insert into forget table", err) {
return
}
// wake up the forget goroutine if it's blocked:
select {
case mysql.wakeForgetter <- e{}:
default:
}
}
func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
if mysql.db == nil {
return
}
@ -316,6 +531,15 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
}
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
if err != nil {
return
}
err = mysql.insertAccountMessageEntry(ctx, id, account)
if err != nil {
return
}
return
}
@ -354,6 +578,15 @@ func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64
return
}
func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
if account == "" || !mysql.isTrackingAccountMessages() {
return
}
_, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
mysql.logError("could not insert account-message entry", err)
return
}
func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
if mysql.db == nil {
return
@ -399,10 +632,102 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
}
}
err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
if err != nil {
return
}
return
}
func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
// note that accountName is the unfolded name
func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
if mysql.db == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
_, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
if err != nil {
return
}
if accountName != "*" {
var item history.Item
err = unmarshalItem(data, &item)
// delete if the entry is corrupt
if err == nil && item.AccountName != accountName {
return ErrDisallowed
}
}
err = mysql.deleteHistoryIDs(ctx, []uint64{id})
mysql.logError("couldn't delete msgid", err)
return
}
func (mysql *MySQL) Export(account string, writer io.Writer) {
if mysql.db == nil {
return
}
var err error
var lastSeen uint64
for {
rows := func() (count int) {
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
defer cancel()
rows, rowsErr := mysql.db.QueryContext(ctx, `
SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
INNER JOIN history ON history.id = account_messages.history_id
INNER JOIN sequence ON account_messages.history_id = sequence.history_id
WHERE account_messages.account = ? AND account_messages.history_id > ?
LIMIT ?`, account, lastSeen, cleanupRowLimit)
if rowsErr != nil {
err = rowsErr
return
}
defer rows.Close()
for rows.Next() {
var id uint64
var blob, jsonBlob []byte
var target string
var item history.Item
err = rows.Scan(&id, &blob, &target)
if err != nil {
return
}
err = unmarshalItem(blob, &item)
if err != nil {
return
}
item.CfCorrespondent = target
jsonBlob, err = json.Marshal(item)
if err != nil {
return
}
count++
if lastSeen < id {
lastSeen = id
}
writer.Write(jsonBlob)
writer.Write([]byte{'\n'})
}
return
}()
if rows == 0 || err != nil {
break
}
}
mysql.logError("could not export history", err)
return
}
func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
// sequence.nanotime > (
// SELECT sequence.nanotime FROM sequence, history
@ -415,15 +740,27 @@ func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.
if err != nil {
return
}
row := mysql.db.QueryRowContext(ctx, `
SELECT sequence.nanotime FROM sequence
cols := `sequence.nanotime`
if includeData {
cols = `sequence.nanotime, sequence.history_id, history.data`
}
row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT %s FROM sequence
INNER JOIN history ON history.id = sequence.history_id
WHERE history.msgid = ? LIMIT 1;`, decoded)
WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
var nanotime int64
err = row.Scan(&nanotime)
if mysql.logError("could not resolve msgid to time", err) {
if !includeData {
err = row.Scan(&nanotime)
} else {
err = row.Scan(&nanotime, &id, &data)
}
if err != sql.ErrNoRows {
mysql.logError("could not resolve msgid to time", err)
}
if err != nil {
return
}
result = time.Unix(0, nanotime).UTC()
return
}
@ -519,14 +856,14 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
startTime := start.Time
if start.Msgid != "" {
startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
if err != nil {
return nil, false, err
}
}
endTime := end.Time
if end.Msgid != "" {
endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
if err != nil {
return nil, false, err
}