mirror of
https://github.com/ergochat/ergo.git
synced 2025-12-20 02:00:11 -08:00
parent
2e9a0d4b2d
commit
4052cd12fe
9 changed files with 320 additions and 52 deletions
|
|
@ -4,7 +4,6 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
|
@ -12,6 +11,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
|
@ -36,7 +36,7 @@ const (
|
|||
keySchemaVersion = "db.version"
|
||||
// minor version indicates rollback-safe upgrades, i.e.,
|
||||
// you can downgrade oragono and everything will work
|
||||
latestDbMinorVersion = "1"
|
||||
latestDbMinorVersion = "2"
|
||||
keySchemaMinorVersion = "db.minorversion"
|
||||
cleanupRowLimit = 50
|
||||
cleanupPauseTime = 10 * time.Minute
|
||||
|
|
@ -53,6 +53,7 @@ type MySQL struct {
|
|||
insertHistory *sql.Stmt
|
||||
insertSequence *sql.Stmt
|
||||
insertConversation *sql.Stmt
|
||||
insertCorrespondent *sql.Stmt
|
||||
insertAccountMessage *sql.Stmt
|
||||
|
||||
stateMutex sync.Mutex
|
||||
|
|
@ -155,10 +156,24 @@ func (mysql *MySQL) fixSchemas() (err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = mysql.createCorrespondentsTable()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if err == nil && minorVersion == "1" {
|
||||
// upgrade from 2.1 to 2.2: create the correspondents table
|
||||
err = mysql.createCorrespondentsTable()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if err == nil && minorVersion != latestDbMinorVersion {
|
||||
// TODO: if minorVersion < latestDbMinorVersion, upgrade,
|
||||
// if latestDbMinorVersion < minorVersion, ignore because backwards compatible
|
||||
|
|
@ -202,6 +217,11 @@ func (mysql *MySQL) createTables() (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
err = mysql.createCorrespondentsTable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mysql.createComplianceTables()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -210,6 +230,19 @@ func (mysql *MySQL) createTables() (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (mysql *MySQL) createCorrespondentsTable() (err error) {
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
target VARBINARY(%[1]d) NOT NULL,
|
||||
correspondent VARBINARY(%[1]d) NOT NULL,
|
||||
nanotime BIGINT UNSIGNED NOT NULL,
|
||||
UNIQUE KEY (target, correspondent),
|
||||
KEY (target, nanotime),
|
||||
KEY (nanotime)
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) createComplianceTables() (err error) {
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
|
||||
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
|
||||
|
|
@ -275,12 +308,16 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
|
|||
|
||||
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
|
||||
|
||||
if maxNanotime != 0 {
|
||||
mysql.deleteCorrespondents(ctx, maxNanotime)
|
||||
}
|
||||
|
||||
return len(ids), mysql.deleteHistoryIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
|
||||
// can't use ? binding for a variable number of arguments, build the IN clause manually
|
||||
var inBuf bytes.Buffer
|
||||
var inBuf strings.Builder
|
||||
inBuf.WriteByte('(')
|
||||
for i, id := range ids {
|
||||
if i != 0 {
|
||||
|
|
@ -289,22 +326,23 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err
|
|||
fmt.Fprintf(&inBuf, "%d", id)
|
||||
}
|
||||
inBuf.WriteRune(')')
|
||||
inClause := inBuf.String()
|
||||
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if mysql.isTrackingAccountMessages() {
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes()))
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -351,6 +389,18 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id
|
|||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
|
||||
result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
|
||||
if err != nil {
|
||||
mysql.logError("error deleting correspondents", err)
|
||||
} else {
|
||||
count, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wait for forget queue items and process them one by one
|
||||
func (mysql *MySQL) forgetLoop() {
|
||||
defer func() {
|
||||
|
|
@ -470,6 +520,12 @@ func (mysql *MySQL) prepareStatements() (err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
|
||||
(target, correspondent, nanotime) VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
|
||||
(history_id, account) VALUES (?, ?);`)
|
||||
if err != nil {
|
||||
|
|
@ -557,6 +613,12 @@ func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, corresp
|
|||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
|
||||
_, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
|
||||
mysql.logError("could not insert conversations entry", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
|
||||
value, err := marshalItem(&item)
|
||||
if mysql.logError("could not marshal item", err) {
|
||||
|
|
@ -621,6 +683,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if recipientAccount != "" && sender != recipient {
|
||||
|
|
@ -632,6 +698,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
|
||||
|
|
@ -804,7 +874,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
|
|||
direction = "DESC"
|
||||
}
|
||||
|
||||
var queryBuf bytes.Buffer
|
||||
var queryBuf strings.Builder
|
||||
|
||||
args := make([]interface{}, 0, 6)
|
||||
fmt.Fprintf(&queryBuf,
|
||||
|
|
@ -835,6 +905,55 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
|
|||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.CorrespondentListing, err error) {
|
||||
after, before, ascending := history.MinMaxAsc(after, before, cutoff)
|
||||
direction := "ASC"
|
||||
if !ascending {
|
||||
direction = "DESC"
|
||||
}
|
||||
|
||||
var queryBuf strings.Builder
|
||||
args := make([]interface{}, 0, 4)
|
||||
queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
|
||||
WHERE target = ?`)
|
||||
args = append(args, target)
|
||||
if !after.IsZero() {
|
||||
queryBuf.WriteString(" AND correspondents.nanotime > ?")
|
||||
args = append(args, after.UnixNano())
|
||||
}
|
||||
if !before.IsZero() {
|
||||
queryBuf.WriteString(" AND correspondents.nanotime < ?")
|
||||
args = append(args, before.UnixNano())
|
||||
}
|
||||
fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
|
||||
args = append(args, limit)
|
||||
query := queryBuf.String()
|
||||
|
||||
rows, err := mysql.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var correspondent string
|
||||
var nanotime int64
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&correspondent, &nanotime)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
results = append(results, history.CorrespondentListing{
|
||||
CfCorrespondent: correspondent,
|
||||
Time: time.Unix(0, nanotime),
|
||||
})
|
||||
}
|
||||
|
||||
if !ascending {
|
||||
history.ReverseCorrespondents(results)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) Close() {
|
||||
// closing the database will close our prepared statements as well
|
||||
if mysql.db != nil {
|
||||
|
|
@ -852,7 +971,7 @@ type mySQLHistorySequence struct {
|
|||
cutoff time.Time
|
||||
}
|
||||
|
||||
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
|
||||
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -860,25 +979,38 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
|
|||
if start.Msgid != "" {
|
||||
startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
endTime := end.Time
|
||||
if end.Msgid != "" {
|
||||
endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
|
||||
return results, (err == nil), err
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
|
||||
return history.GenericAround(s, start, limit)
|
||||
}
|
||||
|
||||
func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.CorrespondentListing, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
|
||||
defer cancel()
|
||||
|
||||
// TODO accept msgids here?
|
||||
startTime := start.Time
|
||||
endTime := end.Time
|
||||
|
||||
results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
|
||||
seq.mysql.logError("could not read correspondents", err)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
|
||||
return &mySQLHistorySequence{
|
||||
target: target,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue