forked from External/ergo
reactions
This commit is contained in:
parent
d73b6bac86
commit
3b0fecd381
9 changed files with 189 additions and 115 deletions
|
|
@ -6,6 +6,7 @@ package mysql
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -29,17 +30,9 @@ var (
|
|||
const (
|
||||
// maximum length in bytes of any message target (nickname or channel name) in its
|
||||
// canonicalized (i.e., casefolded) state:
|
||||
MaxTargetLength = 64
|
||||
|
||||
// latest schema of the db
|
||||
latestDbSchema = "2"
|
||||
keySchemaVersion = "db.version"
|
||||
// minor version indicates rollback-safe upgrades, i.e.,
|
||||
// you can downgrade oragono and everything will work
|
||||
latestDbMinorVersion = "2"
|
||||
keySchemaMinorVersion = "db.minorversion"
|
||||
cleanupRowLimit = 50
|
||||
cleanupPauseTime = 10 * time.Minute
|
||||
MaxTargetLength = 64
|
||||
cleanupRowLimit = 50
|
||||
cleanupPauseTime = 10 * time.Minute
|
||||
)
|
||||
|
||||
type e struct{}
|
||||
|
|
@ -52,6 +45,13 @@ type MySQL struct {
|
|||
insertConversation *sql.Stmt
|
||||
insertAccountMessage *sql.Stmt
|
||||
|
||||
getReactionsQuery *sql.Stmt
|
||||
getSingleReaction *sql.Stmt
|
||||
addReaction *sql.Stmt
|
||||
deleteReaction *sql.Stmt
|
||||
|
||||
getMessageById *sql.Stmt
|
||||
|
||||
stateMutex sync.Mutex
|
||||
config Config
|
||||
|
||||
|
|
@ -124,102 +124,17 @@ func (mysql *MySQL) Open() (err error) {
|
|||
}
|
||||
|
||||
func (mysql *MySQL) fixSchemas() (err error) {
|
||||
_, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
|
||||
key_name VARCHAR(32) primary key,
|
||||
value VARCHAR(32) NOT NULL
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var schema string
|
||||
err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema)
|
||||
if err == sql.ErrNoRows {
|
||||
err = mysql.createTables()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
} else if err == nil && schema != latestDbSchema {
|
||||
// TODO figure out what to do about schema changes
|
||||
return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var minorVersion string
|
||||
err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
// XXX for now, the only minor version upgrade is the account tracking tables
|
||||
err = mysql.createComplianceTables()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if err == nil && minorVersion == "1" {
|
||||
// upgrade from 2.1 to 2.2: create the correspondents table
|
||||
_, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if err == nil && minorVersion != latestDbMinorVersion {
|
||||
// TODO: if minorVersion < latestDbMinorVersion, upgrade,
|
||||
// if latestDbMinorVersion < minorVersion, ignore because backwards compatible
|
||||
}
|
||||
// 3M now handles this
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) createTables() (err error) {
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE history (
|
||||
msgid BINARY(16) NOT NULL PRIMARY KEY,
|
||||
data BLOB NOT NULL,
|
||||
target VARBINARY(%[1]d) NOT NULL,
|
||||
sender VARBINARY(%[1]d) NOT NULL,
|
||||
nanotime BIGINT UNSIGNED NOT NULL,
|
||||
pm boolean as (SUBSTRING(target, 1, 1) != "#") PERSISTENT,
|
||||
KEY (msgid(4))
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength, MaxTargetLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mysql.createComplianceTables()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3M now handles this
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mysql *MySQL) createComplianceTables() (err error) {
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages (
|
||||
history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY,
|
||||
account VARBINARY(%[1]d) NOT NULL,
|
||||
KEY (account, history_id)
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
account VARBINARY(%[1]d) NOT NULL
|
||||
) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3M now handles this
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -289,7 +204,8 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err
|
|||
return
|
||||
}
|
||||
}
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause))
|
||||
fmt.Printf(`DELETE FROM history WHERE msgid in %s;`, inClause)
|
||||
_, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE msgid in %s;`, inClause))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -299,6 +215,7 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err
|
|||
|
||||
func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) {
|
||||
before := timestampSnowflake(time.Now().Add(-age))
|
||||
maxNanotime = time.Now().Add(-age).Unix() * 1000000000
|
||||
|
||||
rows, err := mysql.db.QueryContext(ctx, `
|
||||
SELECT history.msgid
|
||||
|
|
@ -310,8 +227,7 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id
|
|||
}
|
||||
defer rows.Close()
|
||||
|
||||
idset := make(map[uint64]struct{}, cleanupRowLimit)
|
||||
ids = make([]uint64, len(idset))
|
||||
ids = make([]uint64, cleanupRowLimit)
|
||||
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
|
|
@ -323,6 +239,7 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id
|
|||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
ids = ids[0:i]
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -441,6 +358,38 @@ func (mysql *MySQL) prepareStatements() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
mysql.getReactionsQuery, err = mysql.db.Prepare(`select react, count(*) as total, (select JSON_ARRAYAGG(user)
|
||||
from reactions r
|
||||
where r.msgid = main.msgid
|
||||
and r.react = main.react
|
||||
limit 3) as sample
|
||||
from reactions as main
|
||||
where main.msgid = ?
|
||||
group by react, msgid;`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mysql.getSingleReaction, err = mysql.db.Prepare(`SELECT COUNT(*) FROM reactions WHERE msgid = ? AND user = ? AND react = ?`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mysql.deleteReaction, err = mysql.db.Prepare(`DELETE FROM reactions WHERE msgid = ? AND user = ? AND react = ?`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mysql.addReaction, err = mysql.db.Prepare(`INSERT INTO reactions(msgid, user, react) VALUES (?, ?, ?)`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mysql.getMessageById, err = mysql.db.Prepare(`SELECT msgid, data, target, sender, nanotime, pm FROM history WHERE msgid = ?`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -649,6 +598,44 @@ func (mysql *MySQL) Export(account string, writer io.Writer) {
|
|||
return*/
|
||||
}
|
||||
|
||||
// Kinda an intermediary function due to the CEF DB structure
|
||||
func (mysql *MySQL) GetMessage(msgid string) (id uint64, item history.Item, target string, sender string, nanotime uint64, pm bool, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
var data []byte
|
||||
|
||||
row := mysql.getMessageById.QueryRowContext(ctx, msgid)
|
||||
err = row.Scan(&id, &data, &target, &sender, &nanotime, &pm)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unmarshalItem(data, &item)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) HasReactionFromUser(msgid string, user string, reaction string) (exists bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
row := mysql.getSingleReaction.QueryRowContext(ctx, msgid, user, reaction)
|
||||
var count int
|
||||
row.Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (mysql *MySQL) AddReaction(msgid string, user string, reaction string) (err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
_, err = mysql.addReaction.ExecContext(ctx, msgid, user, reaction)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) DeleteReaction(msgid string, user string, reaction string) (err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
|
||||
defer cancel()
|
||||
_, err = mysql.deleteReaction.ExecContext(ctx, msgid, user, reaction)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -662,7 +649,7 @@ func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData b
|
|||
// May have to adjust it some day
|
||||
row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT %s FROM history
|
||||
WHERE history.msgid = CAST(? AS INT) LIMIT 1;`, cols), msgid)
|
||||
WHERE history.msgid = CAST(? AS UNSIGNED) LIMIT 1;`, cols), msgid)
|
||||
var nanoSeq sql.NullInt64
|
||||
if !includeData {
|
||||
err = row.Scan(&nanoSeq)
|
||||
|
|
@ -694,8 +681,10 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter
|
|||
|
||||
for rows.Next() {
|
||||
var blob []byte
|
||||
var msgid uint64
|
||||
var item history.Item
|
||||
err = rows.Scan(&blob)
|
||||
|
||||
err = rows.Scan(&blob, &msgid)
|
||||
if mysql.logError("could not scan history item", err) {
|
||||
return
|
||||
}
|
||||
|
|
@ -703,6 +692,25 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter
|
|||
if mysql.logError("could not unmarshal history item", err) {
|
||||
return
|
||||
}
|
||||
reactions, rErr := mysql.getReactionsQuery.Query(msgid)
|
||||
if mysql.logError("could not get reactions", rErr) {
|
||||
return
|
||||
}
|
||||
var react string
|
||||
var total int
|
||||
var sample string
|
||||
for reactions.Next() {
|
||||
reactions.Scan(&react, &total, &sample)
|
||||
var sampleDecoded []string
|
||||
json.Unmarshal([]byte(sample), &sampleDecoded)
|
||||
|
||||
item.Reactions = append(item.Reactions, history.Reaction{
|
||||
Name: react,
|
||||
Total: total,
|
||||
SampleUsers: sampleDecoded,
|
||||
})
|
||||
}
|
||||
|
||||
results = append(results, item)
|
||||
}
|
||||
return
|
||||
|
|
@ -722,13 +730,12 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
|
|||
}
|
||||
|
||||
var queryBuf strings.Builder
|
||||
|
||||
args := make([]interface{}, 0, 7)
|
||||
if correspondent == "" {
|
||||
fmt.Fprintf(&queryBuf, "SELECT history.data from history WHERE target = ? ")
|
||||
fmt.Fprintf(&queryBuf, "SELECT data, msgid FROM history WHERE target = ? ")
|
||||
args = append(args, target)
|
||||
} else {
|
||||
fmt.Fprintf(&queryBuf, "SELECT history.data from history WHERE (target = ? and sender = ?) OR (target = ? and sender = ?)")
|
||||
fmt.Fprintf(&queryBuf, "SELECT data, msgid FROM history WHERE (target = ? and sender = ?) OR (target = ? and sender = ?)")
|
||||
args = append(args, target, correspondent, correspondent, target)
|
||||
}
|
||||
|
||||
|
|
@ -737,7 +744,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent
|
|||
args = append(args, after.UnixNano())
|
||||
}
|
||||
if !before.IsZero() {
|
||||
fmt.Fprintf(&queryBuf, " AND nanotime < ?")
|
||||
fmt.Fprintf(&queryBuf, " AND nanotime <= ?")
|
||||
args = append(args, before.UnixNano())
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue