diff --git a/Makefile b/Makefile index e69da6c4..2be9474a 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,7 @@ test: cd irc/modes && go test . && go vet . cd irc/mysql && go test . && go vet . cd irc/passwd && go test . && go vet . + cd irc/sno && go test . && go vet . cd irc/utils && go test . && go vet . ./.check-gofmt.sh diff --git a/irc/modes.go b/irc/modes.go index e4396793..f8a1e85d 100644 --- a/irc/modes.go +++ b/irc/modes.go @@ -75,30 +75,28 @@ func ApplyUserModeChanges(client *Client, changes modes.ModeChanges, force bool, } } else { // server notices are weird - if !client.HasMode(modes.Operator) { + if !client.HasMode(modes.Operator) || change.Op == modes.List { continue } - var masks []sno.Mask - if change.Op == modes.Add || change.Op == modes.Remove { - var newArg string - for _, char := range change.Arg { - mask := sno.Mask(char) - if sno.ValidMasks[mask] { - masks = append(masks, mask) - newArg += string(char) - } - } - change.Arg = newArg - } - if change.Op == modes.Add { + + currentMasks := client.server.snomasks.MasksEnabled(client) + addMasks, removeMasks, newArg := sno.EvaluateSnomaskChanges(change.Op == modes.Add, change.Arg, currentMasks) + + success := false + if len(addMasks) != 0 { oper := client.Oper() // #1176: require special operator privileges to subscribe to snomasks if oper.HasRoleCapab("snomasks") || oper.HasRoleCapab("ban") { - client.server.snomasks.AddMasks(client, masks...) - applied = append(applied, change) + success = true + client.server.snomasks.AddMasks(client, addMasks...) } - } else if change.Op == modes.Remove { - client.server.snomasks.RemoveMasks(client, masks...) + } + if len(removeMasks) != 0 { + success = true + client.server.snomasks.RemoveMasks(client, removeMasks...) + } + if success { + change.Arg = newArg applied = append(applied, change) } } diff --git a/irc/modes/modes.go b/irc/modes/modes.go index f7efedf5..f91cc06b 100644 --- a/irc/modes/modes.go +++ b/irc/modes/modes.go @@ -212,12 +212,10 @@ func ParseUserModeChanges(params ...string) (ModeChanges, map[rune]bool) { // put arg into modechange if needed switch Mode(mode) { case ServerNotice: - // always require arg + // arg is optional for ServerNotice (we accept bare `-s`) if len(params) > skipArgs { change.Arg = params[skipArgs] skipArgs++ - } else { - continue } } diff --git a/irc/modes/modes_test.go b/irc/modes/modes_test.go index 27f04888..67d28c2f 100644 --- a/irc/modes/modes_test.go +++ b/irc/modes/modes_test.go @@ -15,6 +15,38 @@ func assertEqual(supplied, expected interface{}, t *testing.T) { } } +func TestParseUserModeChanges(t *testing.T) { + emptyUnknown := make(map[rune]bool) + changes, unknown := ParseUserModeChanges("+i") + assertEqual(unknown, emptyUnknown, t) + assertEqual(changes, ModeChanges{ModeChange{Op: Add, Mode: Invisible}}, t) + + // no-op change to sno + changes, unknown = ParseUserModeChanges("+is") + assertEqual(unknown, emptyUnknown, t) + assertEqual(changes, ModeChanges{ModeChange{Op: Add, Mode: Invisible}, ModeChange{Op: Add, Mode: ServerNotice}}, t) + + // add snomasks + changes, unknown = ParseUserModeChanges("+is", "ac") + assertEqual(unknown, emptyUnknown, t) + assertEqual(changes, ModeChanges{ModeChange{Op: Add, Mode: Invisible}, ModeChange{Op: Add, Mode: ServerNotice, Arg: "ac"}}, t) + + // remove snomasks + changes, unknown = ParseUserModeChanges("+s", "-cx") + assertEqual(unknown, emptyUnknown, t) + assertEqual(changes, ModeChanges{ModeChange{Op: Add, Mode: ServerNotice, Arg: "-cx"}}, t) + + // remove all snomasks (arg is parsed but has no meaning) + changes, unknown = ParseUserModeChanges("-is", "ac") + assertEqual(unknown, emptyUnknown, t) + assertEqual(changes, ModeChanges{ModeChange{Op: Remove, Mode: Invisible}, ModeChange{Op: Remove, Mode: ServerNotice, Arg: "ac"}}, t) + + // remove all snomasks + changes, unknown = ParseUserModeChanges("-is") + assertEqual(unknown, emptyUnknown, t) + assertEqual(changes, ModeChanges{ModeChange{Op: Remove, Mode: Invisible}, ModeChange{Op: Remove, Mode: ServerNotice}}, t) +} + func TestIssue874(t *testing.T) { emptyUnknown := make(map[rune]bool) modes, unknown := ParseChannelModeChanges("+k") diff --git a/irc/sno/constants.go b/irc/sno/constants.go index 1eb1bc5a..542ab57f 100644 --- a/irc/sno/constants.go +++ b/irc/sno/constants.go @@ -7,6 +7,8 @@ package sno // Mask is a type of server notice mask. type Mask rune +type Masks []Mask + // Notice mask types const ( LocalAnnouncements Mask = 'a' @@ -18,8 +20,8 @@ const ( LocalQuits Mask = 'q' Stats Mask = 't' LocalAccounts Mask = 'u' - LocalXline Mask = 'x' LocalVhosts Mask = 'v' + LocalXline Mask = 'x' ) var ( @@ -39,17 +41,17 @@ var ( } // ValidMasks contains the snomasks that we support. - ValidMasks = map[Mask]bool{ - LocalAnnouncements: true, - LocalConnects: true, - LocalChannels: true, - LocalKills: true, - LocalNicks: true, - LocalOpers: true, - LocalQuits: true, - Stats: true, - LocalAccounts: true, - LocalXline: true, - LocalVhosts: true, + ValidMasks = []Mask{ + LocalAnnouncements, + LocalConnects, + LocalChannels, + LocalKills, + LocalNicks, + LocalOpers, + LocalQuits, + Stats, + LocalAccounts, + LocalVhosts, + LocalXline, } ) diff --git a/irc/sno/utils.go b/irc/sno/utils.go new file mode 100644 index 00000000..572ab394 --- /dev/null +++ b/irc/sno/utils.go @@ -0,0 +1,87 @@ +// Copyright (c) 2020 Shivaram Lingamneni +// released under the MIT license + +package sno + +import ( + "strings" +) + +func IsValidMask(r rune) bool { + for _, m := range ValidMasks { + if m == Mask(r) { + return true + } + } + return false +} + +func (masks Masks) String() string { + var buf strings.Builder + buf.Grow(len(masks)) + for _, m := range masks { + buf.WriteRune(rune(m)) + } + return buf.String() +} + +func (masks Masks) Contains(mask Mask) bool { + for _, m := range masks { + if mask == m { + return true + } + } + return false +} + +// Evaluate changes to snomasks made with MODE. There are several cases: +// adding snomasks with `/mode +s a` or `/mode +s +a`, removing them with `/mode +s -a`, +// adding all with `/mode +s *` or `/mode +s +*`, removing all with `/mode +s -*` or `/mode -s` +func EvaluateSnomaskChanges(add bool, arg string, currentMasks Masks) (addMasks, removeMasks Masks, newArg string) { + if add { + if len(arg) == 0 { + return + } + add := true + switch arg[0] { + case '+': + arg = arg[1:] + case '-': + add = false + arg = arg[1:] + default: + // add + } + if strings.IndexByte(arg, '*') != -1 { + if add { + for _, mask := range ValidMasks { + if !currentMasks.Contains(mask) { + addMasks = append(addMasks, mask) + } + } + } else { + removeMasks = currentMasks + } + } else { + for _, r := range arg { + if IsValidMask(r) { + m := Mask(r) + if add && !currentMasks.Contains(m) { + addMasks = append(addMasks, m) + } else if !add && currentMasks.Contains(m) { + removeMasks = append(removeMasks, m) + } + } + } + } + if len(addMasks) != 0 { + newArg = "+" + addMasks.String() + } else if len(removeMasks) != 0 { + newArg = "-" + removeMasks.String() + } + } else { + removeMasks = currentMasks + newArg = "" + } + return +} diff --git a/irc/sno/utils_test.go b/irc/sno/utils_test.go new file mode 100644 index 00000000..46ce30bf --- /dev/null +++ b/irc/sno/utils_test.go @@ -0,0 +1,53 @@ +// Copyright (c) 2020 Shivaram Lingamneni +// released under the MIT license + +package sno + +import ( + "fmt" + "reflect" + "testing" +) + +func assertEqual(supplied, expected interface{}, t *testing.T) { + if !reflect.DeepEqual(supplied, expected) { + panic(fmt.Sprintf("expected %#v but got %#v", expected, supplied)) + } +} + +func TestEvaluateSnomaskChanges(t *testing.T) { + add, remove, newArg := EvaluateSnomaskChanges(true, "*", nil) + assertEqual(add, Masks{'a', 'c', 'j', 'k', 'n', 'o', 'q', 't', 'u', 'v', 'x'}, t) + assertEqual(len(remove), 0, t) + assertEqual(newArg, "+acjknoqtuvx", t) + + add, remove, newArg = EvaluateSnomaskChanges(true, "*", Masks{'a', 'u'}) + assertEqual(add, Masks{'c', 'j', 'k', 'n', 'o', 'q', 't', 'v', 'x'}, t) + assertEqual(len(remove), 0, t) + assertEqual(newArg, "+cjknoqtvx", t) + + add, remove, newArg = EvaluateSnomaskChanges(true, "-a", Masks{'a', 'u'}) + assertEqual(len(add), 0, t) + assertEqual(remove, Masks{'a'}, t) + assertEqual(newArg, "-a", t) + + add, remove, newArg = EvaluateSnomaskChanges(true, "-*", Masks{'a', 'u'}) + assertEqual(len(add), 0, t) + assertEqual(remove, Masks{'a', 'u'}, t) + assertEqual(newArg, "-au", t) + + add, remove, newArg = EvaluateSnomaskChanges(true, "+c", Masks{'a', 'u'}) + assertEqual(add, Masks{'c'}, t) + assertEqual(len(remove), 0, t) + assertEqual(newArg, "+c", t) + + add, remove, newArg = EvaluateSnomaskChanges(false, "", Masks{'a', 'u'}) + assertEqual(len(add), 0, t) + assertEqual(remove, Masks{'a', 'u'}, t) + assertEqual(newArg, "", t) + + add, remove, newArg = EvaluateSnomaskChanges(false, "*", Masks{'a', 'u'}) + assertEqual(len(add), 0, t) + assertEqual(remove, Masks{'a', 'u'}, t) + assertEqual(newArg, "", t) +} diff --git a/irc/snomanager.go b/irc/snomanager.go index b7e48107..0bd394bc 100644 --- a/irc/snomanager.go +++ b/irc/snomanager.go @@ -24,11 +24,6 @@ func (m *SnoManager) AddMasks(client *Client, masks ...sno.Mask) { defer m.sendListMutex.Unlock() for _, mask := range masks { - // confirm mask is valid - if !sno.ValidMasks[mask] { - continue - } - currentClientList := m.sendLists[mask] if currentClientList == nil { @@ -101,19 +96,23 @@ func (m *SnoManager) Send(mask sno.Mask, content string) { } } -// String returns the snomasks currently enabled. -func (m *SnoManager) String(client *Client) string { +// MasksEnabled returns the snomasks currently enabled. +func (m *SnoManager) MasksEnabled(client *Client) (result sno.Masks) { m.sendListMutex.RLock() defer m.sendListMutex.RUnlock() - var masks string for mask, clients := range m.sendLists { for c := range clients { if c == client { - masks += string(mask) + result = append(result, mask) break } } } - return masks + return +} + +func (m *SnoManager) String(client *Client) string { + masks := m.MasksEnabled(client) + return masks.String() }