diff --git a/irc/cef.go b/irc/cef.go index c7aeac50..6d77a395 100644 --- a/irc/cef.go +++ b/irc/cef.go @@ -34,9 +34,13 @@ func (client *Client) RedisBroadcast(message ...string) { } func (channel *Channel) Broadcast(command string, params ...string) { + channel.BroadcastFrom(channel.server.name, command, params...) +} + +func (channel *Channel) BroadcastFrom(prefix string, command string, params ...string) { for _, member := range channel.Members() { for _, session := range member.Sessions() { - session.Send(nil, member.server.name, command, params...) + session.Send(nil, prefix, command, params...) } } } @@ -153,7 +157,6 @@ func (server *Server) GetUrlMime(url string) string { return "" } contentType, valid := meta["format"].(string) - fmt.Printf("%+v\n", meta) if !valid { println("No content type") return "" diff --git a/irc/channel.go b/irc/channel.go index 07ee7be5..780c1246 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -727,6 +727,9 @@ func (channel *Channel) AddHistoryItem(item history.Item, account string) (err e if !itemIsStorable(&item, channel.server.Config()) { return } + if item.Target == "" { + item.Target = channel.nameCasefolded + } status, target, _ := channel.historyStatus(channel.server.Config()) if status == HistoryPersistent { @@ -1091,9 +1094,9 @@ func (channel *Channel) replayHistoryItems(rb *ResponseBuffer, items []history.I nick := NUHToNick(item.Nick) switch item.Type { case history.Privmsg: - rb.AddSplitMessageFromClient(item.Nick, item.Account, item.IsBot, item.Tags, "PRIVMSG", chname, item.Message) + rb.AddSplitMessageFromClientWithReactions(item.Nick, item.Account, item.IsBot, item.Tags, "PRIVMSG", chname, item.Message, item.Reactions) case history.Notice: - rb.AddSplitMessageFromClient(item.Nick, item.Account, item.IsBot, item.Tags, "NOTICE", chname, item.Message) + rb.AddSplitMessageFromClientWithReactions(item.Nick, item.Account, item.IsBot, item.Tags, "NOTICE", chname, item.Message, item.Reactions) case history.Tagmsg: if eventPlayback { rb.AddSplitMessageFromClient(item.Nick, item.Account, item.IsBot, item.Tags, "TAGMSG", chname, item.Message) diff --git a/irc/client.go b/irc/client.go index e8487732..763144db 100644 --- a/irc/client.go +++ b/irc/client.go @@ -931,11 +931,11 @@ func (client *Client) replayPrivmsgHistory(rb *ResponseBuffer, items []history.I tags = item.Tags } if !isSelfMessage(&item) { - rb.AddSplitMessageFromClient(item.Nick, item.Account, item.IsBot, tags, command, nick, item.Message) + rb.AddSplitMessageFromClientWithReactions(item.Nick, item.Account, item.IsBot, tags, command, nick, item.Message, item.Reactions) } else { // this message was sent *from* the client to another nick; the target is item.Params[0] // substitute client's current nickmask in case client changed nick - rb.AddSplitMessageFromClient(details.nickMask, item.Account, item.IsBot, tags, command, item.Params[0], item.Message) + rb.AddSplitMessageFromClientWithReactions(details.nickMask, item.Account, item.IsBot, tags, command, item.Params[0], item.Message, item.Reactions) } } diff --git a/irc/commands.go b/irc/commands.go index 4bd88dd7..6757da5e 100644 --- a/irc/commands.go +++ b/irc/commands.go @@ -383,6 +383,11 @@ func init() { handler: zncHandler, minParams: 1, }, + // CEF custom commands + "REACT": { + handler: reactHandler, + minParams: 2, + }, } initializeServices() diff --git a/irc/handlers.go b/irc/handlers.go index 984a99b4..57d8662b 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -4085,6 +4085,33 @@ func zncHandler(server *Server, client *Client, msg ircmsg.Message, rb *Response return false } +// REACT : +func reactHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool { + // This directly uses SQL stuff, since it's targeted at CEF, which requires a DB. + _, _, target, sender, _, pm, err := server.historyDB.GetMessage(msg.Params[0]) + if err != nil { + return false + } + + var operation string + if server.historyDB.HasReactionFromUser(msg.Params[0], client.AccountName(), msg.Params[1]) { + server.historyDB.DeleteReaction(msg.Params[0], client.AccountName(), msg.Params[1]) + operation = "DEL" + } else { + server.historyDB.AddReaction(msg.Params[0], client.AccountName(), msg.Params[1]) + operation = "ADD" + } + + if pm { + server.clients.Get(target).Send(nil, client.NickMaskString(), "REACT", operation, msg.Params[0], msg.Params[1]) + server.clients.Get(sender).Send(nil, client.NickMaskString(), "REACT", operation, msg.Params[0], msg.Params[1]) + } else { + server.channels.Get(target).BroadcastFrom(client.NickMaskString(), "REACT", operation, msg.Params[0], msg.Params[1]) + } + + return false +} + // fake handler for unknown commands func unknownCommandHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool { var message string diff --git a/irc/help.go b/irc/help.go index a11bb725..eb8c056e 100644 --- a/irc/help.go +++ b/irc/help.go @@ -634,6 +634,12 @@ for direct use by end users.`, duplicate: true, }, + "react": { + text: `REACT + +Toggles a reaction to a message. CEF-specific`, + }, + // Informational "modes": { textGenerator: modesTextGenerator, diff --git a/irc/history/history.go b/irc/history/history.go index 7b58337b..cf0c09dc 100644 --- a/irc/history/history.go +++ b/irc/history/history.go @@ -32,6 +32,12 @@ const ( initialAutoSize = 32 ) +type Reaction struct { + Name string + Total int + SampleUsers []string +} + // Item represents an event (e.g., a PRIVMSG or a JOIN) and its associated data type Item struct { Type ItemType @@ -49,6 +55,8 @@ type Item struct { // required by CHATHISTORY: Target string `json:"Target"` IsBot bool `json:"IsBot,omitempty"` + + Reactions []Reaction } // HasMsgid tests whether a message has the message id `msgid`. diff --git a/irc/mysql/history.go b/irc/mysql/history.go index 8f324c04..3d3577db 100644 --- a/irc/mysql/history.go +++ b/irc/mysql/history.go @@ -6,6 +6,7 @@ package mysql import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "io" @@ -29,17 +30,9 @@ var ( const ( // maximum length in bytes of any message target (nickname or channel name) in its // canonicalized (i.e., casefolded) state: - MaxTargetLength = 64 - - // latest schema of the db - latestDbSchema = "2" - keySchemaVersion = "db.version" - // minor version indicates rollback-safe upgrades, i.e., - // you can downgrade oragono and everything will work - latestDbMinorVersion = "2" - keySchemaMinorVersion = "db.minorversion" - cleanupRowLimit = 50 - cleanupPauseTime = 10 * time.Minute + MaxTargetLength = 64 + cleanupRowLimit = 50 + cleanupPauseTime = 10 * time.Minute ) type e struct{} @@ -52,6 +45,13 @@ type MySQL struct { insertConversation *sql.Stmt insertAccountMessage *sql.Stmt + getReactionsQuery *sql.Stmt + getSingleReaction *sql.Stmt + addReaction *sql.Stmt + deleteReaction *sql.Stmt + + getMessageById *sql.Stmt + stateMutex sync.Mutex config Config @@ -124,102 +124,17 @@ func (mysql *MySQL) Open() (err error) { } func (mysql *MySQL) fixSchemas() (err error) { - _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata ( - key_name VARCHAR(32) primary key, - value VARCHAR(32) NOT NULL - ) CHARSET=ascii COLLATE=ascii_bin;`) - if err != nil { - return err - } - - var schema string - err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema) - if err == sql.ErrNoRows { - err = mysql.createTables() - if err != nil { - return - } - _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema) - if err != nil { - return - } - _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion) - if err != nil { - return - } - return - } else if err == nil && schema != latestDbSchema { - // TODO figure out what to do about schema changes - return fmt.Errorf("incompatible schema: got %s, expected %s", schema, latestDbSchema) - } else if err != nil { - return err - } - - var minorVersion string - err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion) - if err == sql.ErrNoRows { - // XXX for now, the only minor version upgrade is the account tracking tables - err = mysql.createComplianceTables() - if err != nil { - return - } - _, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion) - if err != nil { - return - } - } else if err == nil && minorVersion == "1" { - // upgrade from 2.1 to 2.2: create the correspondents table - _, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion) - if err != nil { - return - } - } else if err == nil && minorVersion != latestDbMinorVersion { - // TODO: if minorVersion < latestDbMinorVersion, upgrade, - // if latestDbMinorVersion < minorVersion, ignore because backwards compatible - } + // 3M now handles this return } func (mysql *MySQL) createTables() (err error) { - _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE history ( - msgid BINARY(16) NOT NULL PRIMARY KEY, - data BLOB NOT NULL, - target VARBINARY(%[1]d) NOT NULL, - sender VARBINARY(%[1]d) NOT NULL, - nanotime BIGINT UNSIGNED NOT NULL, - pm boolean as (SUBSTRING(target, 1, 1) != "#") PERSISTENT, - KEY (msgid(4)) - ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength, MaxTargetLength)) - if err != nil { - return err - } - - err = mysql.createComplianceTables() - if err != nil { - return err - } - + // 3M now handles this return nil } func (mysql *MySQL) createComplianceTables() (err error) { - _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE account_messages ( - history_id BIGINT UNSIGNED NOT NULL PRIMARY KEY, - account VARBINARY(%[1]d) NOT NULL, - KEY (account, history_id) - ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)) - if err != nil { - return err - } - - _, err = mysql.db.Exec(fmt.Sprintf(`CREATE TABLE forget ( - id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, - account VARBINARY(%[1]d) NOT NULL - ) CHARSET=ascii COLLATE=ascii_bin;`, MaxTargetLength)) - if err != nil { - return err - } - + // 3M now handles this return nil } @@ -289,7 +204,8 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err return } } - _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE id in %s;`, inClause)) + fmt.Printf(`DELETE FROM history WHERE msgid in %s;`, inClause) + _, err = mysql.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM history WHERE msgid in %s;`, inClause)) if err != nil { return } @@ -299,6 +215,7 @@ func (mysql *MySQL) deleteHistoryIDs(ctx context.Context, ids []uint64) (err err func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (ids []uint64, maxNanotime int64, err error) { before := timestampSnowflake(time.Now().Add(-age)) + maxNanotime = time.Now().Add(-age).Unix() * 1000000000 rows, err := mysql.db.QueryContext(ctx, ` SELECT history.msgid @@ -310,8 +227,7 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id } defer rows.Close() - idset := make(map[uint64]struct{}, cleanupRowLimit) - ids = make([]uint64, len(idset)) + ids = make([]uint64, cleanupRowLimit) i := 0 for rows.Next() { @@ -323,6 +239,7 @@ func (mysql *MySQL) selectCleanupIDs(ctx context.Context, age time.Duration) (id ids[i] = id i++ } + ids = ids[0:i] return } @@ -441,6 +358,38 @@ func (mysql *MySQL) prepareStatements() (err error) { return } + mysql.getReactionsQuery, err = mysql.db.Prepare(`select react, count(*) as total, (select JSON_ARRAYAGG(user) + from reactions r + where r.msgid = main.msgid + and r.react = main.react + limit 3) as sample + from reactions as main + where main.msgid = ? + group by react, msgid;`) + if err != nil { + return + } + + mysql.getSingleReaction, err = mysql.db.Prepare(`SELECT COUNT(*) FROM reactions WHERE msgid = ? AND user = ? AND react = ?`) + if err != nil { + return + } + + mysql.deleteReaction, err = mysql.db.Prepare(`DELETE FROM reactions WHERE msgid = ? AND user = ? AND react = ?`) + if err != nil { + return + } + + mysql.addReaction, err = mysql.db.Prepare(`INSERT INTO reactions(msgid, user, react) VALUES (?, ?, ?)`) + if err != nil { + return + } + + mysql.getMessageById, err = mysql.db.Prepare(`SELECT msgid, data, target, sender, nanotime, pm FROM history WHERE msgid = ?`) + if err != nil { + return + } + return } @@ -649,6 +598,44 @@ func (mysql *MySQL) Export(account string, writer io.Writer) { return*/ } +// Kinda an intermediary function due to the CEF DB structure +func (mysql *MySQL) GetMessage(msgid string) (id uint64, item history.Item, target string, sender string, nanotime uint64, pm bool, err error) { + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + var data []byte + + row := mysql.getMessageById.QueryRowContext(ctx, msgid) + err = row.Scan(&id, &data, &target, &sender, &nanotime, &pm) + if err != nil { + return + } + err = unmarshalItem(data, &item) + return +} + +func (mysql *MySQL) HasReactionFromUser(msgid string, user string, reaction string) (exists bool) { + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + row := mysql.getSingleReaction.QueryRowContext(ctx, msgid, user, reaction) + var count int + row.Scan(&count) + return count > 0 +} + +func (mysql *MySQL) AddReaction(msgid string, user string, reaction string) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + _, err = mysql.addReaction.ExecContext(ctx, msgid, user, reaction) + return +} + +func (mysql *MySQL) DeleteReaction(msgid string, user string, reaction string) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout()) + defer cancel() + _, err = mysql.deleteReaction.ExecContext(ctx, msgid, user, reaction) + return +} + func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) { if err != nil { return @@ -662,7 +649,7 @@ func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData b // May have to adjust it some day row := mysql.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT %s FROM history - WHERE history.msgid = CAST(? AS INT) LIMIT 1;`, cols), msgid) + WHERE history.msgid = CAST(? AS UNSIGNED) LIMIT 1;`, cols), msgid) var nanoSeq sql.NullInt64 if !includeData { err = row.Scan(&nanoSeq) @@ -694,8 +681,10 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter for rows.Next() { var blob []byte + var msgid uint64 var item history.Item - err = rows.Scan(&blob) + + err = rows.Scan(&blob, &msgid) if mysql.logError("could not scan history item", err) { return } @@ -703,6 +692,25 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter if mysql.logError("could not unmarshal history item", err) { return } + reactions, rErr := mysql.getReactionsQuery.Query(msgid) + if mysql.logError("could not get reactions", rErr) { + return + } + var react string + var total int + var sample string + for reactions.Next() { + reactions.Scan(&react, &total, &sample) + var sampleDecoded []string + json.Unmarshal([]byte(sample), &sampleDecoded) + + item.Reactions = append(item.Reactions, history.Reaction{ + Name: react, + Total: total, + SampleUsers: sampleDecoded, + }) + } + results = append(results, item) } return @@ -722,13 +730,12 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent } var queryBuf strings.Builder - args := make([]interface{}, 0, 7) if correspondent == "" { - fmt.Fprintf(&queryBuf, "SELECT history.data from history WHERE target = ? ") + fmt.Fprintf(&queryBuf, "SELECT data, msgid FROM history WHERE target = ? ") args = append(args, target) } else { - fmt.Fprintf(&queryBuf, "SELECT history.data from history WHERE (target = ? and sender = ?) OR (target = ? and sender = ?)") + fmt.Fprintf(&queryBuf, "SELECT data, msgid FROM history WHERE (target = ? and sender = ?) OR (target = ? and sender = ?)") args = append(args, target, correspondent, correspondent, target) } @@ -737,7 +744,7 @@ func (mysql *MySQL) betweenTimestamps(ctx context.Context, target, correspondent args = append(args, after.UnixNano()) } if !before.IsZero() { - fmt.Fprintf(&queryBuf, " AND nanotime < ?") + fmt.Fprintf(&queryBuf, " AND nanotime <= ?") args = append(args, before.UnixNano()) } diff --git a/irc/responsebuffer.go b/irc/responsebuffer.go index 85d5fd3c..e17f2996 100644 --- a/irc/responsebuffer.go +++ b/irc/responsebuffer.go @@ -4,7 +4,10 @@ package irc import ( + "github.com/ergochat/ergo/irc/history" "runtime/debug" + "strconv" + "strings" "time" "github.com/ergochat/ergo/irc/caps" @@ -121,8 +124,12 @@ func (rb *ResponseBuffer) AddFromClient(time time.Time, msgid string, fromNickMa rb.AddMessage(msg) } -// AddSplitMessageFromClient adds a new split message from a specific client to our queue. func (rb *ResponseBuffer) AddSplitMessageFromClient(fromNickMask string, fromAccount string, isBot bool, tags map[string]string, command string, target string, message utils.SplitMessage) { + rb.AddSplitMessageFromClientWithReactions(fromNickMask, fromAccount, isBot, tags, command, target, message, nil) +} + +// AddSplitMessageFromClient adds a new split message from a specific client to our queue. +func (rb *ResponseBuffer) AddSplitMessageFromClientWithReactions(fromNickMask string, fromAccount string, isBot bool, tags map[string]string, command string, target string, message utils.SplitMessage, reactions []history.Reaction) { if message.Is512() { if message.Message == "" { // XXX this is a TAGMSG @@ -153,6 +160,14 @@ func (rb *ResponseBuffer) AddSplitMessageFromClient(fromNickMask string, fromAcc } } } + if reactions != nil && len(reactions) >= 1 { + var text string + for _, react := range reactions { + text = strings.Join([]string{message.Msgid, react.Name, strconv.Itoa(react.Total)}, " ") + text += " " + strings.Join(react.SampleUsers, " ") + rb.Add(nil, rb.target.server.name, "REACTIONS", text) + } + } } func (rb *ResponseBuffer) addEchoMessage(tags map[string]string, nickMask, accountName, command, target string, message utils.SplitMessage) {