forked from External/ergo
parent
21958768d8
commit
67f35e5c8a
16 changed files with 819 additions and 108 deletions
|
|
@ -18,5 +18,6 @@ type Config struct {
|
|||
Timeout time.Duration
|
||||
|
||||
// XXX these are copied from elsewhere in the config:
|
||||
ExpireTime time.Duration
|
||||
ExpireTime time.Duration
|
||||
TrackAccountMessages bool
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
|
@ -19,6 +22,10 @@ import (
|
|||
"github.com/oragono/oragono/irc/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDisallowed = errors.New("disallowed")
|
||||
)
|
||||
|
||||
const (
|
||||
// maximum length in bytes of any message target (nickname or channel name) in its
|
||||
// canonicalized (i.e., casefolded) state:
|
||||
|
|
@ -27,30 +34,46 @@ const (
|
|||
// latest schema of the db
|
||||
latestDbSchema = "2"
|
||||
keySchemaVersion = "db.version"
|
||||
cleanupRowLimit = 50
|
||||
cleanupPauseTime = 10 * time.Minute
|
||||
// minor version indicates rollback-safe upgrades, i.e.,
|
||||
// you can downgrade oragono and everything will work
|
||||
latestDbMinorVersion = "1"
|
||||
keySchemaMinorVersion = "db.minorversion"
|
||||
cleanupRowLimit = 50
|
||||
cleanupPauseTime = 10 * time.Minute
|
||||
)
|
||||
|
||||
type MySQL struct {
|
||||
timeout int64
|
||||
db *sql.DB
|
||||
logger *logger.Manager
|
||||
type e struct{}
|
||||
|
||||
insertHistory *sql.Stmt
|
||||
insertSequence *sql.Stmt
|
||||
insertConversation *sql.Stmt
|
||||
type MySQL struct {
|
||||
timeout int64
|
||||
trackAccountMessages uint32
|
||||
db *sql.DB
|
||||
logger *logger.Manager
|
||||
|
||||
insertHistory *sql.Stmt
|
||||
insertSequence *sql.Stmt
|
||||
insertConversation *sql.Stmt
|
||||
insertAccountMessage *sql.Stmt
|
||||
|
||||
stateMutex sync.Mutex
|
||||
config Config
|
||||
|
||||
wakeForgetter chan e
|
||||
}
|
||||
|
||||
func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) {
|
||||
mysql.logger = logger
|
||||
mysql.wakeForgetter = make(chan e, 1)
|
||||
mysql.SetConfig(config)
|
||||
}
|
||||
|
||||
func (mysql *MySQL) SetConfig(config Config) {
|
||||
atomic.StoreInt64(&mysql.timeout, int64(config.Timeout))
|
||||
var trackAccountMessages uint32
|
||||
if config.TrackAccountMessages {
|
||||
trackAccountMessages = 1
|
||||
}
|
||||
atomic.StoreUint32(&mysql.trackAccountMessages, trackAccountMessages)
|
||||
mysql.stateMutex.Lock()
|
||||
mysql.config = config
|
||||
mysql.stateMutex.Unlock()
|
||||
|
|
@ -85,6 +108,7 @@ func (m *MySQL) Open() (err error) {
|
|||
}
|
||||
|
||||
go m.cleanupLoop()
|
||||
go m.forgetLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -109,14 +133,35 @@ func (mysql *MySQL) fixSchemas() (err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
} else if err == nil && schema != latestDbSchema {
|
||||
// TODO figure out what to do about schema changes
|
||||
return &utils.IncompatibleSchemaError{CurrentVersion: schema, RequiredVersion: latestDbSchema}
|
||||
} else {
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
var minorVersion string
|
||||
err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
// XXX for now, the only minor version upgrade is the account tracking tables
|
||||
err = mysql.createComplianceTables()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if err == nil && minorVersion != latestDbMinorVersion {
|
||||
// TODO: if minorVersion < latestDbMinorVersion, upgrade,
|
||||
// if latestDbMinorVersion < minorVersion, ignore because backwards compatible
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) createTables() (err error) {
|
||||
|
|
@ -155,6 +200,32 @@ func (mysql *MySQL) createTables() (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
err = mysql.createComplianceTables()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mysql *MySQL) createComplianceTables() (err error) {
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
|
||||
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
|
||||
account VARBINARY(%[1]d) NOT NULL,
|
||||
KEY (account, history_id)
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
account VARBINARY(%[1]d) NOT NULL
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -191,7 +262,10 @@ func (mysql *MySQL) cleanupLoop() {
|
|||
}
|
||||
|
||||
func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
|
||||
ids, maxNanotime, err := mysql.selectCleanupIDs(age)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
|
||||
defer cancel()
|
||||
|
||||
ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age)
|
||||
if len(ids) == 0 {
|
||||
mysql.logger.Debug("mysql", "found no rows to clean up")
|
||||
return
|
||||
|
|
@ -199,6 +273,10 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
|
|||
|
||||
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
|
||||
|
||||
return len(ids), mysql.deleteHistoryIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
|
||||
// can't use ? binding for a variable number of arguments, build the IN clause manually
|
||||
var inBuf bytes.Buffer
|
||||
inBuf.WriteByte('(')
|
||||
|
|
@ -210,25 +288,30 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
|
|||
}
|
||||
inBuf.WriteRune(')')
|
||||
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
|
||||
if mysql.isTrackingAccountMessages() {
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
count = len(ids)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanotime int64, err error) {
|
||||
rows, err := mysql.db.Query(`
|
||||
func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
|
||||
rows, err := mysql.db.QueryContext(ctx, `
|
||||
SELECT history.id, sequence.nanotime
|
||||
FROM history
|
||||
LEFT JOIN sequence ON history.id = sequence.history_id
|
||||
|
|
@ -266,6 +349,109 @@ func (mysql *MySQL) selectCleanupIDs(age time.Duration) (ids []uint64, maxNanoti
|
|||
return
|
||||
}
|
||||
|
||||
// wait for forget queue items and process them one by one
|
||||
func (mysql *MySQL) forgetLoop() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
mysql.logger.Error("mysql",
|
||||
fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack()))
|
||||
time.Sleep(cleanupPauseTime)
|
||||
go mysql.forgetLoop()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
for {
|
||||
found, err := mysql.doForget()
|
||||
mysql.logError("error processing forget", err)
|
||||
if err != nil {
|
||||
time.Sleep(cleanupPauseTime)
|
||||
}
|
||||
if !found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
<-mysql.wakeForgetter
|
||||
}
|
||||
}
|
||||
|
||||
// dequeue an item from the forget queue and process it
|
||||
func (mysql *MySQL) doForget() (found bool, err error) {
|
||||
id, account, err := func() (id int64, account string, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
|
||||
defer cancel()
|
||||
|
||||
row := mysql.db.QueryRowContext(ctx,
|
||||
`SELECT forget.id, forget.account FROM forget LIMIT 1;`)
|
||||
err = row.Scan(&id, &account)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, "", nil
|
||||
}
|
||||
return
|
||||
}()
|
||||
|
||||
if err != nil || account == "" {
|
||||
return false, err
|
||||
}
|
||||
|
||||
found = true
|
||||
|
||||
var count int
|
||||
for {
|
||||
start := time.Now()
|
||||
count, err = mysql.doForgetIteration(account)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
if count == 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(elapsed)
|
||||
}
|
||||
|
||||
mysql.logger.Debug("mysql", "forget complete for account", account)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
|
||||
defer cancel()
|
||||
_, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) doForgetIteration(account string) (count int, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
|
||||
defer cancel()
|
||||
|
||||
rows, err := mysql.db.QueryContext(ctx, `
|
||||
SELECT account_messages.history_id
|
||||
FROM account_messages
|
||||
WHERE account_messages.account = ?
|
||||
LIMIT ?;`, account, cleanupRowLimit)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []uint64
|
||||
for rows.Next() {
|
||||
var id uint64
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account))
|
||||
err = mysql.deleteHistoryIDs(ctx, ids)
|
||||
return len(ids), err
|
||||
}
|
||||
|
||||
func (mysql *MySQL) prepareStatements() (err error) {
|
||||
mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history
|
||||
(data, msgid) VALUES (?, ?);`)
|
||||
|
|
@ -282,6 +468,11 @@ func (mysql *MySQL) prepareStatements() (err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
|
||||
(history_id, account) VALUES (?, ?);`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -290,6 +481,10 @@ func (mysql *MySQL) getTimeout() time.Duration {
|
|||
return time.Duration(atomic.LoadInt64(&mysql.timeout))
|
||||
}
|
||||
|
||||
func (mysql *MySQL) isTrackingAccountMessages() bool {
|
||||
return atomic.LoadUint32(&mysql.trackAccountMessages) != 0
|
||||
}
|
||||
|
||||
func (mysql *MySQL) logError(context string, err error) (quit bool) {
|
||||
if err != nil {
|
||||
mysql.logger.Error("mysql", context, err.Error())
|
||||
|
|
@ -298,7 +493,27 @@ func (mysql *MySQL) logError(context string, err error) (quit bool) {
|
|||
return false
|
||||
}
|
||||
|
||||
func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error) {
|
||||
func (mysql *MySQL) Forget(account string) {
|
||||
if mysql.db == nil || account == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
_, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account)
|
||||
if mysql.logError("can't insert into forget table", err) {
|
||||
return
|
||||
}
|
||||
|
||||
// wake up the forget goroutine if it's blocked:
|
||||
select {
|
||||
case mysql.wakeForgetter <- e{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) {
|
||||
if mysql.db == nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -316,6 +531,15 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item) (err error)
|
|||
}
|
||||
|
||||
err = mysql.insertSequenceEntry(ctx, target, item.Message.Time.UnixNano(), id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = mysql.insertAccountMessageEntry(ctx, id, account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -354,6 +578,15 @@ func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64
|
|||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) {
|
||||
if account == "" || !mysql.isTrackingAccountMessages() {
|
||||
return
|
||||
}
|
||||
_, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
|
||||
mysql.logError("could not insert account-message entry", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) {
|
||||
if mysql.db == nil {
|
||||
return
|
||||
|
|
@ -399,10 +632,102 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
|
|||
}
|
||||
}
|
||||
|
||||
err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.Time, err error) {
|
||||
// note that accountName is the unfolded name
|
||||
func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
|
||||
if mysql.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
_, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if accountName != "*" {
|
||||
var item history.Item
|
||||
err = unmarshalItem(data, &item)
|
||||
// delete if the entry is corrupt
|
||||
if err == nil && item.AccountName != accountName {
|
||||
return ErrDisallowed
|
||||
}
|
||||
}
|
||||
|
||||
err = mysql.deleteHistoryIDs(ctx, []uint64{id})
|
||||
mysql.logError("couldn't delete msgid", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) Export(account string, writer io.Writer) {
|
||||
if mysql.db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
var lastSeen uint64
|
||||
for {
|
||||
rows := func() (count int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime)
|
||||
defer cancel()
|
||||
|
||||
rows, rowsErr := mysql.db.QueryContext(ctx, `
|
||||
SELECT account_messages.history_id, history.data, sequence.target FROM account_messages
|
||||
INNER JOIN history ON history.id = account_messages.history_id
|
||||
INNER JOIN sequence ON account_messages.history_id = sequence.history_id
|
||||
WHERE account_messages.account = ? AND account_messages.history_id > ?
|
||||
LIMIT ?`, account, lastSeen, cleanupRowLimit)
|
||||
if rowsErr != nil {
|
||||
err = rowsErr
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id uint64
|
||||
var blob, jsonBlob []byte
|
||||
var target string
|
||||
var item history.Item
|
||||
err = rows.Scan(&id, &blob, &target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unmarshalItem(blob, &item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
item.CfCorrespondent = target
|
||||
jsonBlob, err = json.Marshal(item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
count++
|
||||
if lastSeen < id {
|
||||
lastSeen = id
|
||||
}
|
||||
writer.Write(jsonBlob)
|
||||
writer.Write([]byte{'\n'})
|
||||
}
|
||||
return
|
||||
}()
|
||||
if rows == 0 || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mysql.logError("could not export history", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
|
||||
// in theory, we could optimize out a roundtrip to the database by using a subquery instead:
|
||||
// sequence.nanotime > (
|
||||
// SELECT sequence.nanotime FROM sequence, history
|
||||
|
|
@ -415,15 +740,27 @@ func (mysql *MySQL) msgidToTime(ctx context.Context, msgid string) (result time.
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
row := mysql.db.QueryRowContext(ctx, `
|
||||
SELECT sequence.nanotime FROM sequence
|
||||
cols := `sequence.nanotime`
|
||||
if includeData {
|
||||
cols = `sequence.nanotime, sequence.history_id, history.data`
|
||||
}
|
||||
row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT %s FROM sequence
|
||||
INNER JOIN history ON history.id = sequence.history_id
|
||||
WHERE history.msgid = ? LIMIT 1;`, decoded)
|
||||
WHERE history.msgid = ? LIMIT 1;`, cols), decoded)
|
||||
var nanotime int64
|
||||
err = row.Scan(&nanotime)
|
||||
if mysql.logError("could not resolve msgid to time", err) {
|
||||
if !includeData {
|
||||
err = row.Scan(&nanotime)
|
||||
} else {
|
||||
err = row.Scan(&nanotime, &id, &data)
|
||||
}
|
||||
if err != sql.ErrNoRows {
|
||||
mysql.logError("could not resolve msgid to time", err)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result = time.Unix(0, nanotime).UTC()
|
||||
return
|
||||
}
|
||||
|
|
@ -519,14 +856,14 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
|
|||
|
||||
startTime := start.Time
|
||||
if start.Msgid != "" {
|
||||
startTime, err = s.mysql.msgidToTime(ctx, start.Msgid)
|
||||
startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
endTime := end.Time
|
||||
if end.Msgid != "" {
|
||||
endTime, err = s.mysql.msgidToTime(ctx, end.Msgid)
|
||||
endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue