Implements the new `CHATHISTORY LISTCORRESPONDENTS` API.
This commit is contained in:
Shivaram Lingamneni 2021-04-06 00:46:07 -04:00
parent 2e9a0d4b2d
commit 4052cd12fe
9 changed files with 320 additions and 52 deletions

View file

@ -4,7 +4,6 @@
package mysql
import (
"bytes"
"context"
"database/sql"
"encoding/json"
@ -12,6 +11,7 @@ import (
"fmt"
"io"
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
@ -36,7 +36,7 @@ const (
keySchemaVersion = "db.version"
// minor version indicates rollback-safe upgrades, i.e.,
// you can downgrade oragono and everything will work
latestDbMinorVersion = "1"
latestDbMinorVersion = "2"
keySchemaMinorVersion = "db.minorversion"
cleanupRowLimit = 50
cleanupPauseTime = 10 * time.Minute
@ -53,6 +53,7 @@ type MySQL struct {
insertHistory *sql.Stmt
insertSequence *sql.Stmt
insertConversation *sql.Stmt
insertCorrespondent *sql.Stmt
insertAccountMessage *sql.Stmt
stateMutex sync.Mutex
@ -155,10 +156,24 @@ func (mysql *MySQL) fixSchemas() (err error) {
if err != nil {
return
}
err = mysql.createCorrespondentsTable()
if err != nil {
return
}
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
if err != nil {
return
}
} else if err == nil && minorVersion == "1" {
// upgrade from 2.1 to 2.2: create the correspondents table
err = mysql.createCorrespondentsTable()
if err != nil {
return
}
_, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
if err != nil {
return
}
} else if err == nil && minorVersion != latestDbMinorVersion {
// TODO: if minorVersion < latestDbMinorVersion, upgrade,
// if latestDbMinorVersion < minorVersion, ignore because backwards compatible
@ -202,6 +217,11 @@ func (mysql *MySQL) createTables() (err error) {
return err
}
err = mysql.createCorrespondentsTable()
if err != nil {
return err
}
err = mysql.createComplianceTables()
if err != nil {
return err
@ -210,6 +230,19 @@ func (mysql *MySQL) createTables() (err error) {
return nil
}
func (mysql *MySQL) createCorrespondentsTable() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE correspondents (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
target VARBINARY(%[1]d) NOT NULL,
correspondent VARBINARY(%[1]d) NOT NULL,
nanotime BIGINT UNSIGNED NOT NULL,
UNIQUE KEY (target, correspondent),
KEY (target, nanotime),
KEY (nanotime)
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
return
}
func (mysql *MySQL) createComplianceTables() (err error) {
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
@ -275,12 +308,16 @@ func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) {
mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime)))
if maxNanotime != 0 {
mysql.deleteCorrespondents(ctx, maxNanotime)
}
return len(ids), mysql.deleteHistoryIDs(ctx, ids)
}
func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) {
// can't use ? binding for a variable number of arguments, build the IN clause manually
var inBuf bytes.Buffer
var inBuf strings.Builder
inBuf.WriteByte('(')
for i, id := range ids {
if i != 0 {
@ -289,22 +326,23 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err
fmt.Fprintf(&inBuf, "%d", id)
}
inBuf.WriteRune(')')
inClause := inBuf.String()
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inBuf.Bytes()))
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM conversations WHERE history_id in %s;`, inClause))
if err != nil {
return
}
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inBuf.Bytes()))
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM sequence WHERE history_id in %s;`, inClause))
if err != nil {
return
}
if mysql.isTrackingAccountMessages() {
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inBuf.Bytes()))
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause))
if err != nil {
return
}
}
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inBuf.Bytes()))
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
if err != nil {
return
}
@ -351,6 +389,18 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id
return
}
func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
result, err := mysql.db.ExecContext(ctx, `DELETE FROM correspondents WHERE nanotime <= (?);`, threshold)
if err != nil {
mysql.logError("error deleting correspondents", err)
} else {
count, err := result.RowsAffected()
if err != nil {
mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count))
}
}
}
// wait for forget queue items and process them one by one
func (mysql *MySQL) forgetLoop() {
defer func() {
@ -470,6 +520,12 @@ func (mysql *MySQL) prepareStatements() (err error) {
if err != nil {
return
}
mysql.insertCorrespondent, err = mysql.db.Prepare(`INSERT INTO correspondents
(target, correspondent, nanotime) VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE nanotime = GREATEST(nanotime, ?);`)
if err != nil {
return
}
mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages
(history_id, account) VALUES (?, ?);`)
if err != nil {
@ -557,6 +613,12 @@ func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, corresp
return
}
func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
_, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
mysql.logError("could not insert conversations entry", err)
return
}
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
value, err := marshalItem(&item)
if mysql.logError("could not marshal item", err) {
@ -621,6 +683,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
if err != nil {
return
}
err = mysql.insertCorrespondentsEntry(ctx, senderAccount, recipient, nanotime, id)
if err != nil {
return
}
}
if recipientAccount != "" && sender != recipient {
@ -632,6 +698,10 @@ func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipient
if err != nil {
return
}
err = mysql.insertCorrespondentsEntry(ctx, recipientAccount, sender, nanotime, id)
if err != nil {
return
}
}
err = mysql.insertAccountMessageEntry(ctx, id, senderAccount)
@ -804,7 +874,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
direction = "DESC"
}
var queryBuf bytes.Buffer
var queryBuf strings.Builder
args := make([]interface{}, 0, 6)
fmt.Fprintf(&queryBuf,
@ -835,6 +905,55 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
return
}
func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.CorrespondentListing, err error) {
after, before, ascending := history.MinMaxAsc(after, before, cutoff)
direction := "ASC"
if !ascending {
direction = "DESC"
}
var queryBuf strings.Builder
args := make([]interface{}, 0, 4)
queryBuf.WriteString(`SELECT correspondents.correspondent, correspondents.nanotime from correspondents
WHERE target = ?`)
args = append(args, target)
if !after.IsZero() {
queryBuf.WriteString(" AND correspondents.nanotime > ?")
args = append(args, after.UnixNano())
}
if !before.IsZero() {
queryBuf.WriteString(" AND correspondents.nanotime < ?")
args = append(args, before.UnixNano())
}
fmt.Fprintf(&queryBuf, " ORDER BY correspondents.nanotime %s LIMIT ?;", direction)
args = append(args, limit)
query := queryBuf.String()
rows, err := mysql.db.QueryContext(ctx, query, args...)
if err != nil {
return
}
defer rows.Close()
var correspondent string
var nanotime int64
for rows.Next() {
err = rows.Scan(&correspondent, &nanotime)
if err != nil {
return
}
results = append(results, history.CorrespondentListing{
CfCorrespondent: correspondent,
Time: time.Unix(0, nanotime),
})
}
if !ascending {
history.ReverseCorrespondents(results)
}
return
}
func (mysql *MySQL) Close() {
// closing the database will close our prepared statements as well
if mysql.db != nil {
@ -852,7 +971,7 @@ type mySQLHistorySequence struct {
cutoff time.Time
}
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, complete bool, err error) {
func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) {
ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout())
defer cancel()
@ -860,25 +979,38 @@ func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (
if start.Msgid != "" {
startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false)
if err != nil {
return nil, false, err
return nil, err
}
}
endTime := end.Time
if end.Msgid != "" {
endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false)
if err != nil {
return nil, false, err
return nil, err
}
}
results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit)
return results, (err == nil), err
return results, err
}
func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) {
return history.GenericAround(s, start, limit)
}
func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.CorrespondentListing, err error) {
ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
defer cancel()
// TODO accept msgids here?
startTime := start.Time
endTime := end.Time
results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
seq.mysql.logError("could not read correspondents", err)
return
}
func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence {
return &mySQLHistorySequence{
target: target,