Ban support.

This commit is contained in:
Mikkel Krautz 2011-05-14 17:22:29 +02:00
parent 2b20d7a555
commit c1861a4312
7 changed files with 329 additions and 0 deletions

View file

@ -13,6 +13,7 @@ PACKAGES = \
pkg/blobstore \ pkg/blobstore \
pkg/serverconf \ pkg/serverconf \
pkg/sessionpool \ pkg/sessionpool \
pkg/ban \
pkg/sqlite pkg/sqlite
GCFLAGS = \ GCFLAGS = \
@ -22,6 +23,7 @@ GCFLAGS = \
-Ipkg/blobstore/_obj \ -Ipkg/blobstore/_obj \
-Ipkg/serverconf/_obj \ -Ipkg/serverconf/_obj \
-Ipkg/sessionpool/_obj \ -Ipkg/sessionpool/_obj \
-Ipkg/ban/_obj \
-Ipkg/sqlite/_obj -Ipkg/sqlite/_obj
LDFLAGS = \ LDFLAGS = \
@ -31,6 +33,7 @@ LDFLAGS = \
-Lpkg/blobstore/_obj \ -Lpkg/blobstore/_obj \
-Lpkg/serverconf/_obj \ -Lpkg/serverconf/_obj \
-Lpkg/sessionpool/_obj \ -Lpkg/sessionpool/_obj \
-Lpkg/ban/_obj \
-Lpkg/sqlite/_obj -Lpkg/sqlite/_obj
GOFILES = \ GOFILES = \

View file

@ -8,6 +8,7 @@ import (
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"gob" "gob"
"grumble/ban"
"grumble/serverconf" "grumble/serverconf"
"io" "io"
"io/ioutil" "io/ioutil"
@ -17,6 +18,7 @@ import (
type frozenServer struct { type frozenServer struct {
Id int "id" Id int "id"
Config map[string]string "config" Config map[string]string "config"
Bans []ban.Ban "bans"
Channels []frozenChannel "channels" Channels []frozenChannel "channels"
Users []frozenUser "users" Users []frozenUser "users"
} }
@ -101,6 +103,11 @@ func (server *Server) Freeze() (fs frozenServer, err os.Error) {
fs.Id = int(server.Id) fs.Id = int(server.Id)
fs.Config = server.cfg.GetAll() fs.Config = server.cfg.GetAll()
server.banlock.RLock()
fs.Bans = make([]ban.Ban, len(server.Bans))
copy(fs.Bans, server.Bans)
server.banlock.RUnlock()
channels := []frozenChannel{} channels := []frozenChannel{}
for _, c := range server.Channels { for _, c := range server.Channels {
fc, err := c.Freeze() fc, err := c.Freeze()
@ -229,6 +236,8 @@ func NewServerFromFrozen(filename string) (s *Server, err os.Error) {
s.cfg = serverconf.New(fs.Config) s.cfg = serverconf.New(fs.Config)
} }
s.Bans = fs.Bans
// Add all channels, but don't hook up parent/child relationships // Add all channels, but don't hook up parent/child relationships
// until all of them are loaded. // until all of them are loaded.
for _, fc := range fs.Channels { for _, fc := range fs.Channels {

View file

@ -11,6 +11,7 @@ import (
"net" "net"
"cryptstate" "cryptstate"
"fmt" "fmt"
"grumble/ban"
"grumble/blobstore" "grumble/blobstore"
) )
@ -850,6 +851,65 @@ func (server *Server) handleUserStateMessage(client *Client, msg *Message) {
} }
func (server *Server) handleBanListMessage(client *Client, msg *Message) { func (server *Server) handleBanListMessage(client *Client, msg *Message) {
banlist := &mumbleproto.BanList{}
err := proto.Unmarshal(msg.buf, banlist)
if err != nil {
client.Panic(err.String())
return
}
if !server.HasPermission(client, server.root, BanPermission) {
client.sendPermissionDenied(client, server.root, BanPermission)
}
if banlist.Query != nil && *banlist.Query != false {
banlist.Reset()
server.banlock.RLock()
defer server.banlock.RUnlock()
for _, ban := range server.Bans {
entry := &mumbleproto.BanList_BanEntry{}
entry.Address = ban.IP
entry.Mask = proto.Uint32(uint32(ban.Mask))
entry.Name = proto.String(ban.Username)
entry.Hash = proto.String(ban.CertHash)
entry.Reason = proto.String(ban.Reason)
entry.Start = proto.String(ban.ISOStartDate())
entry.Duration = proto.Uint32(ban.Duration)
banlist.Bans = append(banlist.Bans, entry)
}
if err := client.sendProtoMessage(MessageBanList, banlist); err != nil {
client.Panic("Unable to send BanList")
}
} else {
server.banlock.Lock()
defer server.banlock.Unlock()
server.Bans = server.Bans[0:0]
for _, entry := range banlist.Bans {
ban := ban.Ban{}
ban.IP = entry.Address
ban.Mask = int(*entry.Mask)
if entry.Name != nil {
ban.Username = *entry.Name
}
if entry.Hash != nil {
ban.CertHash = *entry.Hash
}
if entry.Reason != nil {
ban.Reason = *entry.Reason
}
if entry.Start != nil {
ban.SetISOStartDate(*entry.Start)
}
if entry.Duration != nil {
ban.Duration = *entry.Duration
}
server.Bans = append(server.Bans, ban)
}
client.Printf("Banlist updated")
}
} }
// Broadcast text messages // Broadcast text messages

7
pkg/ban/Makefile Normal file
View file

@ -0,0 +1,7 @@
include $(GOROOT)/src/Make.inc
TARG = grumble/ban
GOFILES = \
ban.go \
include $(GOROOT)/src/Make.pkg

81
pkg/ban/ban.go Normal file
View file

@ -0,0 +1,81 @@
// Copyright (c) 2011 The Grumble Authors
// The use of this source code is goverened by a BSD-style
// license that can be found in the LICENSE-file.
package ban
import (
"net"
"time"
)
const (
ISODate = "2006-01-02T15:04:05"
)
type Ban struct {
IP net.IP
Mask int
Username string
CertHash string
Reason string
Start int64
Duration uint32
}
// Create a net.IPMask from a specified amount of mask bits
func (ban Ban) IPMask() (mask net.IPMask) {
allbits := ban.Mask
for i := 0; i < 16; i++ {
bits := allbits
if bits > 0 {
if bits > 8 {
bits = 8
}
mask = append(mask, byte((1<<uint(bits))-1))
} else {
mask = append(mask, byte(0))
}
allbits -= 8
}
return
}
// Check whether an IP matches a Ban
func (ban Ban) Match(ip net.IP) bool {
banned := ban.IP.Mask(ban.IPMask())
masked := ip.Mask(ban.IPMask())
return banned.Equal(masked)
}
// Set Start date from an ISO 8601 date (in UTC)
func (ban *Ban) SetISOStartDate(isodate string) {
startTime, err := time.Parse(ISODate, isodate)
if err != nil {
ban.Start = 0
} else {
ban.Start = startTime.Seconds()
}
}
// Return the currently set start date as an ISO 8601-formatted
// date (in UTC).
func (ban Ban) ISOStartDate() string {
startTime := time.SecondsToUTC(ban.Start)
return startTime.Format(ISODate)
}
// Check whether a ban has expired
func (ban Ban) IsExpired() bool {
// ∞-case
if ban.Duration == 0 {
return false
}
// Expiry check
expiryTime := ban.Start + int64(ban.Duration)
if time.Seconds() > expiryTime {
return true
}
return false
}

138
pkg/ban/ban_test.go Normal file
View file

@ -0,0 +1,138 @@
package ban
import (
"bytes"
"net"
"testing"
"time"
)
func TestMaskNonPowerOf8(t *testing.T) {
mask := []byte{0xff, 0x1f, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
b := Ban{}
b.Mask = 13
if !bytes.Equal(b.IPMask(), mask) {
t.Errorf("Mask mismatch: %v, %v", mask, []byte(b.IPMask()))
}
}
func TestMaksPowerOf2(t *testing.T) {
mask := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}
b := Ban{}
b.Mask = 64
if !bytes.Equal(b.IPMask(), mask) {
t.Errorf("Mask mismatch: %v, %v", mask, []byte(b.IPMask()))
}
}
func TestMatchV4(t *testing.T) {
b := Ban{}
b.IP = net.ParseIP("192.168.1.1")
b.Mask = 24+96 // ipv4 /24
if len(b.IP) == 0 {
t.Errorf("Invalid IP")
}
clientIp := net.ParseIP("192.168.1.50")
if len(clientIp) == 0 {
t.Errorf("Invalid IP")
}
if b.Match(clientIp) != true {
t.Errorf("IPv4: unexpected match")
}
}
func TestMismatchV4(t *testing.T) {
b := Ban{}
b.IP = net.ParseIP("192.168.1.1")
b.Mask = 24+96 // ipv4 /24
if len(b.IP) == 0 {
t.Errorf("Invalid IP")
}
clientIp := net.ParseIP("192.168.2.1")
if len(clientIp) == 0 {
t.Errorf("Invalid IP")
}
if b.Match(clientIp) == true {
t.Errorf("IPv4: unexpected mismatch")
}
}
func TestMatchV6(t *testing.T) {
b := Ban {}
b.IP = net.ParseIP("2a00:1450:400b:c00::63")
b.Mask = 64
if len(b.IP) == 0 {
t.Errorf("Invalid IP")
}
clientIp := net.ParseIP("2a00:1450:400b:c00::54")
if len(clientIp) == 0 {
t.Errorf("Invalid IP")
}
if b.Match(clientIp) != true {
t.Errorf("IPv6: unexpected match")
}
}
func TestMismatchV6(t *testing.T) {
b := Ban{}
b.IP = net.ParseIP("2a00:1450:400b:c00::63")
b.Mask = 64
if len(b.IP) == 0 {
t.Errorf("Invalid IP")
}
clientIp := net.ParseIP("2a00:1450:400b:deaf:42f0:cafe:babe:54")
if len(clientIp) == 0 {
t.Errorf("Invalid IP")
}
if b.Match(clientIp) == true {
t.Errorf("IPv6: unexpected mismatch")
}
}
func TestISODate(t *testing.T) {
sometime := "2011-05-14T13:48:00"
b := Ban{}
b.SetISOStartDate(sometime)
if sometime != b.ISOStartDate() {
t.Errorf("UNIX timestamp mismatch: %v %v", b.ISOStartDate(), sometime)
}
}
func TestInfiniteExpiry(t *testing.T) {
b := Ban{}
b.Start = time.Seconds()-10
b.Duration = 0
if b.IsExpired() {
t.Errorf("∞ should not expire")
}
}
func TestExpired(t *testing.T) {
b := Ban{}
b.Start = time.Seconds()-10
b.Duration = 9
if !b.IsExpired() {
t.Errorf("Should have expired 1 second ago")
}
}
func TestNotExpired(t *testing.T) {
b := Ban{}
b.Start = time.Seconds()
b.Duration = 60*60*24
if b.IsExpired() {
t.Errorf("Should expire in 24 hours")
}
}

View file

@ -21,6 +21,7 @@ import (
"cryptstate" "cryptstate"
"fmt" "fmt"
"gob" "gob"
"grumble/ban"
"grumble/blobstore" "grumble/blobstore"
"grumble/serverconf" "grumble/serverconf"
"grumble/sessionpool" "grumble/sessionpool"
@ -97,6 +98,10 @@ type Server struct {
// ACL cache // ACL cache
aclcache ACLCache aclcache ACLCache
// Bans
banlock sync.RWMutex
Bans []ban.Ban
// Logging // Logging
*log.Logger *log.Logger
} }
@ -1104,6 +1109,21 @@ func (s *Server) removeRegisteredUserFromChannel(uid uint32, channel *Channel) {
} }
} }
// Is the incoming connection conn banned?
func (server *Server) IsBanned(conn net.Conn) bool {
server.banlock.RLock()
defer server.banlock.RUnlock()
for _, ban := range server.Bans {
addr := conn.RemoteAddr().(*net.TCPAddr)
if ban.Match(addr.IP) && !ban.IsExpired() {
return true
}
}
return false
}
// The accept loop of the server. // The accept loop of the server.
func (s *Server) ListenAndMurmur() { func (s *Server) ListenAndMurmur() {
// Launch the event handler goroutine // Launch the event handler goroutine
@ -1159,6 +1179,17 @@ func (s *Server) ListenAndMurmur() {
continue continue
} }
// Is the client banned?
// fixme(mkrautz): Clean up expired bans
if s.IsBanned(conn) {
s.Printf("Rejected client %v: Banned", conn.RemoteAddr())
err := conn.Close()
if err != nil {
s.Printf("Unable to close connection: %v", err)
}
continue
}
// Create a new client connection from our *tls.Conn // Create a new client connection from our *tls.Conn
// which wraps net.TCPConn. // which wraps net.TCPConn.
err = s.NewClient(conn) err = s.NewClient(conn)