implement draft/webpush (#2205)

This commit is contained in:
Shivaram Lingamneni 2025-01-13 18:47:21 -08:00 committed by GitHub
parent efd3764337
commit 36e5451aa5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 2091 additions and 100 deletions

View file

@ -6,6 +6,7 @@
package irc
import (
"context"
"crypto/x509"
"fmt"
"maps"
@ -32,6 +33,7 @@ import (
"github.com/ergochat/ergo/irc/oauth2"
"github.com/ergochat/ergo/irc/sno"
"github.com/ergochat/ergo/irc/utils"
"github.com/ergochat/ergo/irc/webpush"
)
const (
@ -46,6 +48,10 @@ const (
// maximum total read markers that can be stored
// (writeback of read markers is controlled by lastSeen logic)
maxReadMarkers = 256
// should be long enough to handle multiple notifications in rapid succession,
// short enough that it doesn't waste a lot of RAM per client
pushQueueLengthPerClient = 16
)
const (
@ -71,52 +77,57 @@ var (
// Client is an IRC client.
type Client struct {
account string
accountName string // display name of the account: uncasefolded, '*' if not logged in
accountRegDate time.Time
accountSettings AccountSettings
awayMessage string
channels ChannelSet
ctime time.Time
destroyed bool
modes modes.ModeSet
hostname string
invitedTo map[string]channelInvite
isSTSOnly bool
isKlined bool // #1941: k-line kills are special-cased to suppress some triggered notices/events
languages []string
lastActive time.Time // last time they sent a command that wasn't PONG or similar
lastSeen map[string]time.Time // maps device ID (including "") to time of last received command
readMarkers map[string]time.Time // maps casefolded target to time of last read marker
loginThrottle connection_limits.GenericThrottle
nextSessionID int64 // Incremented when a new session is established
nick string
nickCasefolded string
nickMaskCasefolded string
nickMaskString string // cache for nickmask string since it's used with lots of replies
oper *Oper
preregNick string
proxiedIP net.IP // actual remote IP if using the PROXY protocol
rawHostname string
cloakedHostname string
realname string
realIP net.IP
requireSASLMessage string
requireSASL bool
registered bool
registerCmdSent bool // already sent the draft/register command, can't send it again
dirtyTimestamps bool // lastSeen or readMarkers is dirty
registrationTimer *time.Timer
server *Server
skeleton string
sessions []*Session
stateMutex sync.RWMutex // tier 1
alwaysOn bool
username string
vhost string
history history.Buffer
dirtyBits uint
writebackLock sync.Mutex // tier 1.5
account string
accountName string // display name of the account: uncasefolded, '*' if not logged in
accountRegDate time.Time
accountSettings AccountSettings
awayMessage string
channels ChannelSet
ctime time.Time
destroyed bool
modes modes.ModeSet
hostname string
invitedTo map[string]channelInvite
isSTSOnly bool
isKlined bool // #1941: k-line kills are special-cased to suppress some triggered notices/events
languages []string
lastActive time.Time // last time they sent a command that wasn't PONG or similar
lastSeen map[string]time.Time // maps device ID (including "") to time of last received command
readMarkers map[string]time.Time // maps casefolded target to time of last read marker
loginThrottle connection_limits.GenericThrottle
nextSessionID int64 // Incremented when a new session is established
nick string
nickCasefolded string
nickMaskCasefolded string
nickMaskString string // cache for nickmask string since it's used with lots of replies
oper *Oper
preregNick string
proxiedIP net.IP // actual remote IP if using the PROXY protocol
rawHostname string
cloakedHostname string
realname string
realIP net.IP
requireSASLMessage string
requireSASL bool
registered bool
registerCmdSent bool // already sent the draft/register command, can't send it again
dirtyTimestamps bool // lastSeen or readMarkers is dirty
registrationTimer *time.Timer
server *Server
skeleton string
sessions []*Session
stateMutex sync.RWMutex // tier 1
alwaysOn bool
username string
vhost string
history history.Buffer
dirtyBits uint
writebackLock sync.Mutex // tier 1.5
pushSubscriptions map[string]*pushSubscription
cachedPushSubscriptions []storedPushSubscription
clearablePushMessages map[string]time.Time
pushSubscriptionsExist atomic.Uint32 // this is a cache on len(pushSubscriptions) != 0
pushQueue pushQueue
}
type saslStatus struct {
@ -198,6 +209,8 @@ type Session struct {
autoreplayMissedSince time.Time
batch MultilineBatch
webPushEndpoint string // goroutine-local: web push endpoint registered by the current session
}
// MultilineBatch tracks the state of a client-to-server multiline batch.
@ -407,7 +420,7 @@ func (server *Server) RunClient(conn IRCConn) {
client.run(session)
}
func (server *Server) AddAlwaysOnClient(account ClientAccount, channelToStatus map[string]alwaysOnChannelStatus, lastSeen, readMarkers map[string]time.Time, uModes modes.Modes, realname string) {
func (server *Server) AddAlwaysOnClient(account ClientAccount, channelToStatus map[string]alwaysOnChannelStatus, lastSeen, readMarkers map[string]time.Time, uModes modes.Modes, realname string, pushSubscriptions []storedPushSubscription) {
now := time.Now().UTC()
config := server.Config()
if lastSeen == nil && account.Settings.AutoreplayMissed {
@ -484,6 +497,14 @@ func (server *Server) AddAlwaysOnClient(account ClientAccount, channelToStatus m
if persistenceEnabled(config.Accounts.Multiclient.AutoAway, client.accountSettings.AutoAway) {
client.setAutoAwayNoMutex(config)
}
if len(pushSubscriptions) != 0 {
client.pushSubscriptions = make(map[string]*pushSubscription, len(pushSubscriptions))
for _, sub := range pushSubscriptions {
client.pushSubscriptions[sub.Endpoint] = newPushSubscription(sub)
}
}
client.rebuildPushSubscriptionCache()
}
func (client *Client) resizeHistory(config *Config) {
@ -1780,6 +1801,7 @@ const (
IncludeChannels uint = 1 << iota
IncludeUserModes
IncludeRealname
IncludePushSubscriptions
)
func (client *Client) markDirty(dirtyBits uint) {
@ -1800,7 +1822,7 @@ func (client *Client) wakeWriter() {
}
func (client *Client) writeLoop() {
defer client.server.HandlePanic()
defer client.server.HandlePanic(nil)
for {
client.performWrite(0)
@ -1858,6 +1880,9 @@ func (client *Client) performWrite(additionalDirtyBits uint) {
if (dirtyBits & IncludeRealname) != 0 {
client.server.accounts.saveRealname(account, client.realname)
}
if (dirtyBits & IncludePushSubscriptions) != 0 {
client.server.accounts.savePushSubscriptions(account, client.getPushSubscriptions())
}
}
// Blocking store; see Channel.Store and Socket.BlockingWrite
@ -1877,3 +1902,134 @@ func (client *Client) Store(dirtyBits uint) (err error) {
client.performWrite(dirtyBits)
return nil
}
// pushSubscription represents all the data we track about the state of a push subscription;
// right now every field is persisted, but we may want to persist only a subset in future
type pushSubscription struct {
storedPushSubscription
}
// storedPushSubscription represents a subscription as stored in the database
type storedPushSubscription struct {
Endpoint string
Keys webpush.Keys
LastRefresh time.Time // last time the client sent WEBPUSH REGISTER for this endpoint
LastSuccess time.Time // last time we successfully pushed to this endpoint
}
func newPushSubscription(sub storedPushSubscription) *pushSubscription {
return &pushSubscription{
storedPushSubscription: sub,
// TODO any other initialization here, like rate limiting
}
}
type pushMessage struct {
msg []byte
urgency webpush.Urgency
originatingEndpoint string
cftarget string
time time.Time
}
type pushQueue struct {
workerLock sync.Mutex
queue chan pushMessage
once sync.Once
dropped atomic.Uint64
}
func (c *Client) ensurePushInitialized() {
c.pushQueue.once.Do(c.initializePush)
}
func (c *Client) initializePush() {
// allocate the queue
c.pushQueue.queue = make(chan pushMessage, pushQueueLengthPerClient)
}
func (client *Client) dispatchPushMessage(msg pushMessage) {
client.ensurePushInitialized()
select {
case client.pushQueue.queue <- msg:
if client.pushQueue.workerLock.TryLock() {
go client.pushWorker()
}
default:
client.pushQueue.dropped.Add(1)
}
}
func (client *Client) pushWorker() {
defer client.server.HandlePanic(nil)
defer client.pushQueue.workerLock.Unlock()
for {
select {
case msg := <-client.pushQueue.queue:
for _, sub := range client.getPushSubscriptions() {
if !client.skipPushMessage(msg) {
client.sendAndTrackPush(sub.Endpoint, sub.Keys, msg, true)
}
}
default:
// no more messages, end the goroutine and release the trylock
return
}
}
}
// skipPushMessage waits up to the configured delay for the client to send MARKREAD;
// it returns whether the message has been read
func (client *Client) skipPushMessage(msg pushMessage) bool {
if msg.cftarget == "" || msg.time.IsZero() {
return false
}
config := client.server.Config()
if config.WebPush.Delay == 0 {
return false
}
deadline := msg.time.Add(config.WebPush.Delay)
pause := time.Until(deadline)
if pause > 0 {
time.Sleep(pause)
}
readTimestamp, ok := client.getMarkreadTime(msg.cftarget)
return ok && utils.ReadMarkerLessThanOrEqual(msg.time, readTimestamp)
}
func (client *Client) sendAndTrackPush(endpoint string, keys webpush.Keys, msg pushMessage, updateDB bool) {
if endpoint == msg.originatingEndpoint {
return
}
if msg.cftarget != "" && !msg.time.IsZero() {
client.addClearablePushMessage(msg.cftarget, msg.time)
}
switch client.sendPush(endpoint, keys, msg.urgency, msg.msg) {
case nil:
client.recordPush(endpoint, true)
case webpush.Err404:
client.deletePushSubscription(endpoint, updateDB)
default:
client.recordPush(endpoint, false)
}
}
func (client *Client) sendPush(endpoint string, keys webpush.Keys, urgency webpush.Urgency, msg []byte) error {
config := client.server.Config()
// final sanity check
if !config.WebPush.Enabled {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), config.WebPush.Timeout)
defer cancel()
err := webpush.SendWebPush(ctx, endpoint, keys, config.WebPush.vapidKeys, webpush.UrgencyHigh, config.WebPush.Subscriber, msg)
if err == nil {
client.server.logger.Debug("webpush", "dispatched push to client", client.Nick(), endpoint)
} else {
client.server.logger.Debug("webpush", "failed to dispatch push to client", client.Nick(), endpoint, err.Error())
}
return err
}