// Copyright (c) 2020 Shivaram Lingamneni // released under the MIT license package mysql import ( "context" "database/sql" "encoding/json" "errors" "fmt" "io" "runtime/debug" "slices" "strings" "sync" "sync/atomic" "time" "github.com/ergochat/ergo/irc/history" "github.com/ergochat/ergo/irc/logger" "github.com/ergochat/ergo/irc/utils" _ "github.com/go-sql-driver/mysql" ) var ( ErrDisallowed = errors.New("disallowed") ) const ( // maximum length in bytes of any message target (nickname or channel name) in its // canonicalized (i.e., casefolded) state: MaxTargetLength = 64 cleanupRowLimit = 50 cleanupPauseTime = 10 * time.Minute ) type e struct{} type MySQL struct { db *sql.DB logger *logger.Manager insertHistory *sql.Stmt insertConversation *sql.Stmt insertAccountMessage *sql.Stmt getReactionsQuery *sql.Stmt getSingleReaction *sql.Stmt addReaction *sql.Stmt deleteReaction *sql.Stmt getMessageById *sql.Stmt stateMutex sync.Mutex config Config wakeForgetter chan e timeout atomic.Uint64 trackAccountMessages atomic.Uint32 } func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) { mysql.logger = logger mysql.wakeForgetter = make(chan e, 1) mysql.SetConfig(config) } func (mysql *MySQL) SetConfig(config Config) { mysql.timeout.Store(uint64(config.Timeout)) var trackAccountMessages uint32 if config.TrackAccountMessages { trackAccountMessages = 1 } mysql.trackAccountMessages.Store(trackAccountMessages) mysql.stateMutex.Lock() mysql.config = config mysql.stateMutex.Unlock() } func (mysql *MySQL) getExpireTime() (expireTime time.Duration) { mysql.stateMutex.Lock() expireTime = mysql.config.ExpireTime mysql.stateMutex.Unlock() return } func (mysql *MySQL) Open() (err error) { var address string if mysql.config.SocketPath != "" { address = fmt.Sprintf("unix(%s)", mysql.config.SocketPath) } else if mysql.config.Port != 0 { address = fmt.Sprintf("tcp(%s:%d)", mysql.config.Host, mysql.config.Port) } mysql.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@%s/%s", mysql.config.User, mysql.config.Password, address, mysql.config.HistoryDatabase)) if err != nil { return err } if mysql.config.MaxConns != 0 { mysql.db.SetMaxOpenConns(mysql.config.MaxConns) mysql.db.SetMaxIdleConns(mysql.config.MaxConns) } if mysql.config.ConnMaxLifetime != 0 { mysql.db.SetConnMaxLifetime(mysql.config.ConnMaxLifetime) } err = mysql.fixSchemas() if err != nil { return err } err = mysql.prepareStatements() if err != nil { return err } go mysql.cleanupLoop() go mysql.forgetLoop() return nil } func (mysql *MySQL) fixSchemas() (err error) { // 3M now handles this return } func (mysql *MySQL) createTables() (err error) { // 3M now handles this return nil } func (mysql *MySQL) createComplianceTables() (err error) { // 3M now handles this return nil } func (mysql *MySQL) cleanupLoop() { defer func() { if r := recover(); r != nil { mysql.logger.Error("mysql", fmt.Sprintf("Panic in cleanup routine: %v\n%s", r, debug.Stack())) time.Sleep(cleanupPauseTime) go mysql.cleanupLoop() } }() for { expireTime := mysql.getExpireTime() if expireTime != 0 { for { startTime := time.Now() rowsDeleted, err := mysql.doCleanup(expireTime) elapsed := time.Now().Sub(startTime) mysql.logError("error during row cleanup", err) // keep going as long as we're accomplishing significant work // (don't busy-wait on small numbers of rows expiring): if rowsDeleted < (cleanupRowLimit / 10) { break } // crude backpressure mechanism: if the database is slow, // give it time to process other queries time.Sleep(elapsed) } } time.Sleep(cleanupPauseTime) } } func (mysql *MySQL) doCleanup(age time.Duration) (count int, err error) { ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime) defer cancel() ids, maxNanotime, err := mysql.selectCleanupIDs(ctx, age) if len(ids) == 0 { mysql.logger.Debug("mysql", "found no rows to clean up") return } mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows, max age %s", len(ids), utils.NanoToTimestamp(maxNanotime))) return len(ids), mysql.deleteHistoryIDs(ctx, ids) } func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err error) { // can't use ? binding for a variable number of arguments, build the IN clause manually var inBuf strings.Builder inBuf.WriteByte('(') for i, id := range ids { if i != 0 { inBuf.WriteRune(',') } fmt.Fprintf(&inBuf, "%d", id) } inBuf.WriteRune(')') inClause := inBuf.String() if mysql.isTrackingAccountMessages() { _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM account_messages WHERE history_id in %s;`, inClause)) if err != nil { return } } fmt.Printf(`DELETE FROM history WHERE msgid in %s;`, inClause) _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE msgid in %s;`, inClause)) if err != nil { return } return } func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) { before := timestampSnowflake(time.Now().Add(-age)) maxNanotime = time.Now().Add(-age).Unix() * 1000000000 rows, err := mysql.db.QueryContext(ctx, ` SELECT history.msgid FROM history WHERE msgid < ? ORDER BY history.msgid LIMIT ?;`, before, cleanupRowLimit) if err != nil { return } defer rows.Close() ids = make([]uint64, cleanupRowLimit) i := 0 for rows.Next() { var id uint64 err = rows.Scan(&id) if err != nil { return } ids[i] = id i++ } ids = ids[0:i] return } // wait for forget queue items and process them one by one func (mysql *MySQL) forgetLoop() { defer func() { if r := recover(); r != nil { mysql.logger.Error("mysql", fmt.Sprintf("Panic in forget routine: %v\n%s", r, debug.Stack())) time.Sleep(cleanupPauseTime) go mysql.forgetLoop() } }() for { for { found, err := mysql.doForget() mysql.logError("error processing forget", err) if err != nil { time.Sleep(cleanupPauseTime) } if !found { break } } <-mysql.wakeForgetter } } // dequeue an item from the forget queue and process it func (mysql *MySQL) doForget() (found bool, err error) { id, account, err := func() (id int64, account string, err error) { ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime) defer cancel() row := mysql.db.QueryRowContext(ctx, `SELECT forget.id, forget.account FROM forget LIMIT 1;`) err = row.Scan(&id, &account) if err == sql.ErrNoRows { return 0, "", nil } return }() if err != nil || account == "" { return false, err } found = true var count int for { start := time.Now() count, err = mysql.doForgetIteration(account) elapsed := time.Since(start) if err != nil { return true, err } if count == 0 { break } time.Sleep(elapsed) } mysql.logger.Debug("mysql", "forget complete for account", account) ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime) defer cancel() _, err = mysql.db.ExecContext(ctx, `DELETE FROM forget where id = ?;`, id) return } func (mysql *MySQL) doForgetIteration(account string) (count int, err error) { ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime) defer cancel() rows, err := mysql.db.QueryContext(ctx, ` SELECT account_messages.history_id FROM account_messages WHERE account_messages.account = ? LIMIT ?;`, account, cleanupRowLimit) if err != nil { return } defer rows.Close() var ids []uint64 for rows.Next() { var id uint64 err = rows.Scan(&id) if err != nil { return } ids = append(ids, id) } if len(ids) == 0 { return } mysql.logger.Debug("mysql", fmt.Sprintf("deleting %d history rows from account %s", len(ids), account)) err = mysql.deleteHistoryIDs(ctx, ids) return len(ids), err } func (mysql *MySQL) prepareStatements() (err error) { mysql.insertHistory, err = mysql.db.Prepare(`INSERT INTO history (data, msgid, target, sender, nanotime) VALUES (?, ?, ?, ?, ?);`) if err != nil { return } mysql.insertAccountMessage, err = mysql.db.Prepare(`INSERT INTO account_messages (history_id, account) VALUES (?, ?);`) if err != nil { return } mysql.getReactionsQuery, err = mysql.db.Prepare(`select react, count(*) as total, (select JSON_ARRAYAGG(user) from reactions r where r.msgid = main.msgid and r.react = main.react limit 3) as sample from reactions as main where main.msgid = ? group by react, msgid;`) if err != nil { return } mysql.getSingleReaction, err = mysql.db.Prepare(`SELECT COUNT(*) FROM reactions WHERE msgid = ? AND user = ? AND react = ?`) if err != nil { return } mysql.deleteReaction, err = mysql.db.Prepare(`DELETE FROM reactions WHERE msgid = ? AND user = ? AND react = ?`) if err != nil { return } mysql.addReaction, err = mysql.db.Prepare(`INSERT INTO reactions(msgid, user, react) VALUES (?, ?, ?)`) if err != nil { return } mysql.getMessageById, err = mysql.db.Prepare(`SELECT msgid, data, target, sender, nanotime, pm FROM history WHERE msgid = ?`) if err != nil { return } return } func (mysql *MySQL) getTimeout() time.Duration { return time.Duration(mysql.timeout.Load()) } func (mysql *MySQL) isTrackingAccountMessages() bool { return mysql.trackAccountMessages.Load() != 0 } func (mysql *MySQL) logError(context string, err error) (quit bool) { if err != nil { mysql.logger.Error("mysql", context, err.Error()) return true } return false } func (mysql *MySQL) Forget(account string) { if mysql.db == nil || account == "" { return } ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() _, err := mysql.db.ExecContext(ctx, `INSERT INTO forget (account) VALUES (?);`, account) if mysql.logError("can't insert into forget table", err) { return } // wake up the forget goroutine if it's blocked: select { case mysql.wakeForgetter <- e{}: default: } } func (mysql *MySQL) AddChannelItem(target string, item history.Item, account string) (err error) { if mysql.db == nil { return } if target == "" { return utils.ErrInvalidParams } ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() id, err := mysql.insertBase(ctx, item) if err != nil { return } err = mysql.insertAccountMessageEntry(ctx, id, account) if err != nil { return } return } func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) { var value []byte value, err = marshalItem(&item) if mysql.logError("could not marshal item", err) { return } var account = item.Account if account == "" { account = "*" } result, err := mysql.insertHistory.ExecContext(ctx, value, item.Message.Msgid, item.Target, account, item.Message.Time.UnixNano()) if mysql.logError("could not insert item", err) { return } id, err = result.LastInsertId() if mysql.logError("could not insert item", err) { return } return } func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, account string) (err error) { if account == "" || !mysql.isTrackingAccountMessages() { return } _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account) mysql.logError("could not insert account-message entry", err) return } func (mysql *MySQL) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item history.Item) (err error) { if mysql.db == nil { return } if senderAccount == "" && recipientAccount == "" { return } if sender == "" || recipient == "" { return utils.ErrInvalidParams } ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() _, err = mysql.insertBase(ctx, item) if err != nil { return } return } // note that accountName is the unfolded name func (mysql *MySQL) DeleteMsgid(msgid, account string) (err error) { if mysql.db == nil { return nil } ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() _, id, data, err := mysql.lookupMsgid(ctx, msgid, true) if err != nil { return } if account != "*" { var item history.Item err = unmarshalItem(data, &item) // delete if the entry is corrupt if err == nil && item.Account != account { return ErrDisallowed } } err = mysql.deleteHistoryIDs(ctx, []uint64{id}) mysql.logError("couldn't delete msgid", err) return } func (mysql *MySQL) Export(account string, writer io.Writer) { // no eu presence... // maybe fix this when i know the new schema works return /*if mysql.db == nil { return } var err error var lastSeen uint64 for { rows := func() (count int) { ctx, cancel := context.WithTimeout(context.Background(), cleanupPauseTime) defer cancel() rows, rowsErr := mysql.db.QueryContext(ctx, ` SELECT history.data, msgid, target FROM history WHERE sender = ? AND account_messages.history_id > ? LIMIT ?`, account, lastSeen, cleanupRowLimit) if rowsErr != nil { err = rowsErr return } defer rows.Close() for rows.Next() { var id uint64 var blob, jsonBlob []byte var target string var item history.Item err = rows.Scan(&blob, &id, &target) if err != nil { return } err = unmarshalItem(blob, &item) if err != nil { return } item.Target = target jsonBlob, err = json.Marshal(item) if err != nil { return } count++ if lastSeen < id { lastSeen = id } writer.Write(jsonBlob) writer.Write([]byte{'\n'}) } return }() if rows == 0 || err != nil { break } } mysql.logError("could not export history", err) return*/ } // Kinda an intermediary function due to the CEF DB structure func (mysql *MySQL) GetMessage(msgid string) (id uint64, item history.Item, target string, sender string, nanotime uint64, pm bool, err error) { ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() var data []byte row := mysql.getMessageById.QueryRowContext(ctx, msgid) err = row.Scan(&id, &data, &target, &sender, &nanotime, &pm) if err != nil { return } err = unmarshalItem(data, &item) return } func (mysql *MySQL) HasReactionFromUser(msgid string, user string, reaction string) (exists bool) { ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() row := mysql.getSingleReaction.QueryRowContext(ctx, msgid, user, reaction) var count int row.Scan(&count) return count > 0 } func (mysql *MySQL) AddReaction(msgid string, user string, reaction string) (err error) { ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() _, err = mysql.addReaction.ExecContext(ctx, msgid, user, reaction) return } func (mysql *MySQL) DeleteReaction(msgid string, user string, reaction string) (err error) { ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() _, err = mysql.deleteReaction.ExecContext(ctx, msgid, user, reaction) return } func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) { if err != nil { return } cols := `history.nanotime` if includeData { cols = `history.nanotime, history.id, history.data` } // Since CEF uses snowflakes and vanilla ergo uses blobs, we cast as int to make it function. // May have to adjust it some day row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT %s FROM history WHERE history.msgid = CAST(? AS UNSIGNED) LIMIT 1;`, cols), msgid) var nanoSeq sql.NullInt64 if !includeData { err = row.Scan(&nanoSeq) } else { err = row.Scan(&nanoSeq, &id, &data) } if err != sql.ErrNoRows { mysql.logError("could not resolve msgid to time", err) } if err != nil { return } nanotime := nanoSeq.Int64 if nanotime == 0 { err = sql.ErrNoRows return } result = time.Unix(0, nanotime).UTC() return } func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) { rows, err := mysql.db.QueryContext(ctx, query, args...) if mysql.logError("could not select history items", err) { return } defer rows.Close() for rows.Next() { var blob []byte var msgid uint64 var item history.Item err = rows.Scan(&blob, &msgid) if mysql.logError("could not scan history item", err) { return } err = unmarshalItem(blob, &item) if mysql.logError("could not unmarshal history item", err) { return } reactions, rErr := mysql.getReactionsQuery.Query(msgid) if mysql.logError("could not get reactions", rErr) { return } var react string var total int var sample string for reactions.Next() { reactions.Scan(&react, &total, &sample) var sampleDecoded []string json.Unmarshal([]byte(sample), &sampleDecoded) item.Reactions = append(item.Reactions, history.Reaction{ Name: react, Total: total, SampleUsers: sampleDecoded, }) } results = append(results, item) } return } func timestampSnowflake(t time.Time) uint64 { var ts = t.Unix() & 0xffffffffffff return uint64(ts << 16) } func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent string, after, before, cutoff time.Time, limit int) (results []history.Item, err error) { after, before, ascending := history.MinMaxAsc(after, before, cutoff) direction := "ASC" if !ascending { direction = "DESC" } var queryBuf strings.Builder args := make([]interface{}, 0, 7) if correspondent == "" { fmt.Fprintf(&queryBuf, "SELECT data, msgid FROM history WHERE target = ? ") args = append(args, target) } else { fmt.Fprintf(&queryBuf, "SELECT data, msgid FROM history WHERE (target = ? and sender = ?) OR (target = ? and sender = ?)") args = append(args, target, correspondent, correspondent, target) } if !after.IsZero() { fmt.Fprintf(&queryBuf, " AND nanotime > ?") args = append(args, after.UnixNano()) } if !before.IsZero() { fmt.Fprintf(&queryBuf, " AND nanotime <= ?") args = append(args, before.UnixNano()) } fmt.Fprintf(&queryBuf, " ORDER BY nanotime %[1]s LIMIT ?;", direction) args = append(args, limit) results, err = mysql.selectItems(ctx, queryBuf.String(), args...) if err == nil && !ascending { slices.Reverse(results) } return } func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target string, after, before, cutoff time.Time, limit int) (results []history.TargetListing, err error) { after, before, ascending := history.MinMaxAsc(after, before, cutoff) direction := "ASC" if !ascending { direction = "DESC" } var queryBuf strings.Builder args := make([]interface{}, 0, 5) queryBuf.WriteString(`SELECT target, sender, nanotime from history WHERE target = ? OR (sender = ? and pm = true)`) args = append(args, target, target) if !after.IsZero() { queryBuf.WriteString(" AND nanotime > ?") args = append(args, after.UnixNano()) } if !before.IsZero() { queryBuf.WriteString(" AND nanotime < ?") args = append(args, before.UnixNano()) } fmt.Fprintf(&queryBuf, " ORDER BY nanotime %s LIMIT ?;", direction) args = append(args, limit) query := queryBuf.String() rows, err := mysql.db.QueryContext(ctx, query, args...) if err != nil { return } defer rows.Close() var msgTarget string var msgSender string var nanotime int64 for rows.Next() { err = rows.Scan(&msgTarget, &msgSender, &nanotime) if err != nil { return } if msgTarget == target { results = append(results, history.TargetListing{ CfName: msgSender, Time: time.Unix(0, nanotime), }) } else { results = append(results, history.TargetListing{ CfName: msgTarget, Time: time.Unix(0, nanotime), }) } } if !ascending { slices.Reverse(results) } return } func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) { if mysql.db == nil { return } if len(cfchannels) == 0 { return } ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() var queryBuf strings.Builder args := make([]interface{}, 0, len(results)) // https://dev.mysql.com/doc/refman/8.0/en/group-by-optimization.html // this should be a "loose index scan" queryBuf.WriteString(`SELECT sequence.target, MAX(sequence.nanotime) FROM sequence WHERE sequence.target IN (`) for i, chname := range cfchannels { if i != 0 { queryBuf.WriteString(", ") } queryBuf.WriteByte('?') args = append(args, chname) } queryBuf.WriteString(") GROUP BY sequence.target;") rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...) if mysql.logError("could not query channel listings", err) { return } defer rows.Close() var target string var nanotime int64 for rows.Next() { err = rows.Scan(&target, &nanotime) if mysql.logError("could not scan channel listings", err) { return } results = append(results, history.TargetListing{ CfName: target, Time: time.Unix(0, nanotime), }) } return } func (mysql *MySQL) Close() { // closing the database will close our prepared statements as well if mysql.db != nil { mysql.db.Close() } mysql.db = nil } // implements history.Sequence, emulating a single history buffer (for a channel, // a single user's DMs, or a DM conversation) type mySQLHistorySequence struct { mysql *MySQL target string correspondent string cutoff time.Time } func (s *mySQLHistorySequence) Between(start, end history.Selector, limit int) (results []history.Item, err error) { ctx, cancel := context.WithTimeout(context.Background(), s.mysql.getTimeout()) defer cancel() startTime := start.Time if start.Msgid != "" { startTime, _, _, err = s.mysql.lookupMsgid(ctx, start.Msgid, false) if err != nil { if err == sql.ErrNoRows { return nil, nil } else { return nil, err } } } endTime := end.Time if end.Msgid != "" { endTime, _, _, err = s.mysql.lookupMsgid(ctx, end.Msgid, false) if err != nil { if err == sql.ErrNoRows { return nil, nil } else { return nil, err } } } results, err = s.mysql.betweenTimestamps(ctx, s.target, s.correspondent, startTime, endTime, s.cutoff, limit) return results, err } func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (results []history.Item, err error) { return history.GenericAround(s, start, limit) } func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) { ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout()) defer cancel() // TODO accept msgids here? startTime := start.Time endTime := end.Time results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit) seq.mysql.logError("could not read correspondents", err) return } func (seq *mySQLHistorySequence) Cutoff() time.Time { return seq.cutoff } func (seq *mySQLHistorySequence) Ephemeral() bool { return false } func (mysql *MySQL) MakeSequence(target, correspondent string, cutoff time.Time) history.Sequence { return &mySQLHistorySequence{ target: target, correspondent: correspondent, mysql: mysql, cutoff: cutoff, } } func (mysql *MySQL) GetPMs(casefoldedUser string) (results map[string]int64, err error) { if mysql.db == nil { return } results = make(map[string]int64) ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) defer cancel() var queryBuf strings.Builder args := make([]interface{}, 0) queryBuf.WriteString(`SELECT max(nanotime), target, sender FROM history WHERE target = ? OR (sender = ? and pm = true) GROUP BY target, sender;`) args = append(args, casefoldedUser, casefoldedUser) rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...) if mysql.logError("could not get pms", err) { return } defer rows.Close() var last int64 var target, sender string for rows.Next() { err = rows.Scan(&last, &target, &sender) if mysql.logError("could not get pms", err) { return } // We really don't need nanosecond precision if target != casefoldedUser { results[target] = last / 1000000 } else { results[sender] = last / 1000000 } } return }