mirror of
https://github.com/mumble-voip/grumble.git
synced 2025-12-20 06:10:00 -08:00
Merge 76ff3fb907 into 6f8c2bf2f5
This commit is contained in:
commit
0cb35a7e20
17 changed files with 522 additions and 160 deletions
|
|
@ -21,7 +21,7 @@ var usageTmpl = `usage: grumble [options]
|
||||||
grumble {{.Version}} ({{.BuildDate}})
|
grumble {{.Version}} ({{.BuildDate}})
|
||||||
target: {{.OS}}, {{.Arch}}
|
target: {{.OS}}, {{.Arch}}
|
||||||
|
|
||||||
--help
|
--help, --version
|
||||||
Shows this help listing.
|
Shows this help listing.
|
||||||
|
|
||||||
--datadir <data-dir> (default: {{.DefaultDataDir}})
|
--datadir <data-dir> (default: {{.DefaultDataDir}})
|
||||||
|
|
@ -30,6 +30,20 @@ var usageTmpl = `usage: grumble [options]
|
||||||
--log <log-path> (default: $DATADIR/grumble.log)
|
--log <log-path> (default: $DATADIR/grumble.log)
|
||||||
Log file path.
|
Log file path.
|
||||||
|
|
||||||
|
--ini <config-path> (default: $DATADIR/grumble.ini)
|
||||||
|
Config file path.
|
||||||
|
|
||||||
|
--supw <password> [server-id]
|
||||||
|
Set password for SuperUser account. Optionally takes
|
||||||
|
the virtual server to modify as the first positional argument.
|
||||||
|
|
||||||
|
--readsupw [server-id]
|
||||||
|
Like --supw, but reads from stdin instead.
|
||||||
|
|
||||||
|
--disablesu [server-id]
|
||||||
|
Disables the SuperUser account. Optionally takes
|
||||||
|
the virtual server to modify as the first positional argument.
|
||||||
|
|
||||||
--regen-keys
|
--regen-keys
|
||||||
Force grumble to regenerate its global RSA
|
Force grumble to regenerate its global RSA
|
||||||
keypair (and certificate).
|
keypair (and certificate).
|
||||||
|
|
@ -49,7 +63,12 @@ type args struct {
|
||||||
ShowHelp bool
|
ShowHelp bool
|
||||||
DataDir string
|
DataDir string
|
||||||
LogPath string
|
LogPath string
|
||||||
|
ConfigPath string
|
||||||
|
SuperUserPW string
|
||||||
|
ReadPass bool
|
||||||
|
DisablePass bool
|
||||||
RegenKeys bool
|
RegenKeys bool
|
||||||
|
ServerId int64
|
||||||
SQLiteDB string
|
SQLiteDB string
|
||||||
CleanUp bool
|
CleanUp bool
|
||||||
}
|
}
|
||||||
|
|
@ -63,10 +82,6 @@ func defaultDataDir() string {
|
||||||
return filepath.Join(homedir, dirname)
|
return filepath.Join(homedir, dirname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultLogPath() string {
|
|
||||||
return filepath.Join(defaultDataDir(), "grumble.log")
|
|
||||||
}
|
|
||||||
|
|
||||||
func Usage() {
|
func Usage() {
|
||||||
t, err := template.New("usage").Parse(usageTmpl)
|
t, err := template.New("usage").Parse(usageTmpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -90,9 +105,17 @@ var Args args
|
||||||
func init() {
|
func init() {
|
||||||
flag.Usage = Usage
|
flag.Usage = Usage
|
||||||
|
|
||||||
|
flag.BoolVar(&Args.ShowHelp, "version", false, "")
|
||||||
flag.BoolVar(&Args.ShowHelp, "help", false, "")
|
flag.BoolVar(&Args.ShowHelp, "help", false, "")
|
||||||
|
|
||||||
flag.StringVar(&Args.DataDir, "datadir", defaultDataDir(), "")
|
flag.StringVar(&Args.DataDir, "datadir", defaultDataDir(), "")
|
||||||
flag.StringVar(&Args.LogPath, "log", defaultLogPath(), "")
|
flag.StringVar(&Args.LogPath, "log", "", "")
|
||||||
|
flag.StringVar(&Args.ConfigPath, "ini", "", "")
|
||||||
|
|
||||||
|
flag.StringVar(&Args.SuperUserPW, "supw", "", "")
|
||||||
|
flag.BoolVar(&Args.ReadPass, "readsupw", false, "")
|
||||||
|
flag.BoolVar(&Args.DisablePass, "disablesu", false, "")
|
||||||
|
|
||||||
flag.BoolVar(&Args.RegenKeys, "regen-keys", false, "")
|
flag.BoolVar(&Args.RegenKeys, "regen-keys", false, "")
|
||||||
|
|
||||||
flag.StringVar(&Args.SQLiteDB, "import-murmurdb", "", "")
|
flag.StringVar(&Args.SQLiteDB, "import-murmurdb", "", "")
|
||||||
|
|
|
||||||
|
|
@ -510,7 +510,7 @@ func (client *Client) tlsRecvLoop() {
|
||||||
Release: proto.String("Grumble"),
|
Release: proto.String("Grumble"),
|
||||||
CryptoModes: cryptstate.SupportedModes(),
|
CryptoModes: cryptstate.SupportedModes(),
|
||||||
}
|
}
|
||||||
if client.server.cfg.BoolValue("SendOSInfo") {
|
if client.server.cfg.BoolValue("sendversion") {
|
||||||
version.Os = proto.String(runtime.GOOS)
|
version.Os = proto.String(runtime.GOOS)
|
||||||
version.OsVersion = proto.String("(Unknown version)")
|
version.OsVersion = proto.String("(Unknown version)")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"mumble.info/grumble/pkg/ban"
|
"mumble.info/grumble/pkg/ban"
|
||||||
"mumble.info/grumble/pkg/freezer"
|
"mumble.info/grumble/pkg/freezer"
|
||||||
"mumble.info/grumble/pkg/mumbleproto"
|
"mumble.info/grumble/pkg/mumbleproto"
|
||||||
"mumble.info/grumble/pkg/serverconf"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Freeze a server to disk and closes the log file.
|
// Freeze a server to disk and closes the log file.
|
||||||
|
|
@ -74,7 +73,7 @@ func (server *Server) Freeze() (fs *freezer.Server, err error) {
|
||||||
fs = new(freezer.Server)
|
fs = new(freezer.Server)
|
||||||
|
|
||||||
// Freeze all config kv-pairs
|
// Freeze all config kv-pairs
|
||||||
allCfg := server.cfg.GetAll()
|
allCfg := server.cfg.GetAllPersistent()
|
||||||
for k, v := range allCfg {
|
for k, v := range allCfg {
|
||||||
fs.Config = append(fs.Config, &freezer.ConfigKeyValuePair{
|
fs.Config = append(fs.Config, &freezer.ConfigKeyValuePair{
|
||||||
Key: proto.String(k),
|
Key: proto.String(k),
|
||||||
|
|
@ -420,11 +419,10 @@ func NewServerFromFrozen(name string) (s *Server, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err = NewServer(id)
|
s, err = NewServer(id, configFile.ServerConfig(id, cfgMap))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.cfg = serverconf.New(cfgMap)
|
|
||||||
|
|
||||||
// Unfreeze the server's frozen bans.
|
// Unfreeze the server's frozen bans.
|
||||||
s.UnfreezeBanList(fs.BanList)
|
s.UnfreezeBanList(fs.BanList)
|
||||||
|
|
@ -835,29 +833,3 @@ func (server *Server) UpdateFrozenBans(bans []ban.Ban) {
|
||||||
}
|
}
|
||||||
server.numLogOps += 1
|
server.numLogOps += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConfig writes an updated config value to the datastore.
|
|
||||||
func (server *Server) UpdateConfig(key, value string) {
|
|
||||||
fcfg := &freezer.ConfigKeyValuePair{
|
|
||||||
Key: proto.String(key),
|
|
||||||
Value: proto.String(value),
|
|
||||||
}
|
|
||||||
err := server.freezelog.Put(fcfg)
|
|
||||||
if err != nil {
|
|
||||||
server.Fatal(err)
|
|
||||||
}
|
|
||||||
server.numLogOps += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetConfig writes to the freezelog that the config with key
|
|
||||||
// has been reset to its default value.
|
|
||||||
func (server *Server) ResetConfig(key string) {
|
|
||||||
fcfg := &freezer.ConfigKeyValuePair{
|
|
||||||
Key: proto.String(key),
|
|
||||||
}
|
|
||||||
err := server.freezelog.Put(fcfg)
|
|
||||||
if err != nil {
|
|
||||||
server.Fatal(err)
|
|
||||||
}
|
|
||||||
server.numLogOps += 1
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -7,17 +7,21 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"mumble.info/grumble/pkg/blobstore"
|
"mumble.info/grumble/pkg/blobstore"
|
||||||
"mumble.info/grumble/pkg/logtarget"
|
"mumble.info/grumble/pkg/logtarget"
|
||||||
|
"mumble.info/grumble/pkg/serverconf"
|
||||||
)
|
)
|
||||||
|
|
||||||
var servers map[int64]*Server
|
var servers map[int64]*Server
|
||||||
var blobStore blobstore.BlobStore
|
var blobStore blobstore.BlobStore
|
||||||
|
var configFile *serverconf.ConfigFile
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var err error
|
var err error
|
||||||
|
|
@ -27,6 +31,20 @@ func main() {
|
||||||
Usage()
|
Usage()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if Args.ReadPass {
|
||||||
|
data, err := ioutil.ReadAll(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read password from stdin: %v", err)
|
||||||
|
}
|
||||||
|
Args.SuperUserPW = string(data)
|
||||||
|
}
|
||||||
|
if flag.NArg() > 0 && (Args.SuperUserPW != "" || Args.DisablePass) {
|
||||||
|
Args.ServerId, err = strconv.ParseInt(flag.Arg(0), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to parse server id %v: %v", flag.Arg(0), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Open the data dir to check whether it exists.
|
// Open the data dir to check whether it exists.
|
||||||
dataDir, err := os.Open(Args.DataDir)
|
dataDir, err := os.Open(Args.DataDir)
|
||||||
|
|
@ -36,10 +54,42 @@ func main() {
|
||||||
}
|
}
|
||||||
dataDir.Close()
|
dataDir.Close()
|
||||||
|
|
||||||
// Set up logging
|
// Open the config file
|
||||||
logtarget.Default, err = logtarget.OpenFile(Args.LogPath, os.Stderr)
|
var configFn string
|
||||||
|
if Args.ConfigPath != "" {
|
||||||
|
configFn = Args.ConfigPath
|
||||||
|
} else {
|
||||||
|
configFn = filepath.Join(Args.DataDir, "grumble.ini")
|
||||||
|
}
|
||||||
|
if filepath.Ext(configFn) == ".ini" {
|
||||||
|
// Create it if it doesn't exist
|
||||||
|
configFd, err := os.OpenFile(configFn, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0700)
|
||||||
|
if err == nil {
|
||||||
|
configFd.WriteString(serverconf.DefaultConfigFile)
|
||||||
|
log.Fatalf("Default config written to %v\n", configFn)
|
||||||
|
configFd.Close()
|
||||||
|
} else if err != nil && !os.IsExist(err) {
|
||||||
|
log.Fatalf("Unable to open config file (%v): %v", configFn, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
configFile, err = serverconf.NewConfigFile(configFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Unable to open log file (%v): %v", Args.LogPath, err)
|
log.Fatalf("Unable to open config file (%v): %v", configFn, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config := configFile.GlobalConfig()
|
||||||
|
|
||||||
|
// Set up logging
|
||||||
|
var logFn string
|
||||||
|
if Args.LogPath != "" {
|
||||||
|
logFn = Args.LogPath
|
||||||
|
} else {
|
||||||
|
logFn = config.PathValue("logfile", Args.DataDir)
|
||||||
|
}
|
||||||
|
logtarget.Default, err = logtarget.OpenFile(logFn, os.Stderr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Unable to open log file (%v): %v", logFn, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.SetPrefix("[G] ")
|
log.SetPrefix("[G] ")
|
||||||
|
|
@ -48,6 +98,23 @@ func main() {
|
||||||
log.Printf("Grumble")
|
log.Printf("Grumble")
|
||||||
log.Printf("Using data directory: %s", Args.DataDir)
|
log.Printf("Using data directory: %s", Args.DataDir)
|
||||||
|
|
||||||
|
// Warn on some unsupported configuration options for users migrating from Murmur
|
||||||
|
if config.StringValue("database") != "" {
|
||||||
|
log.Println("* Grumble does not yet support Murmur databases directly (see issue #21 on github).")
|
||||||
|
if driver := config.StringValue("dbDriver"); driver == "QSQLITE" {
|
||||||
|
log.Println(" To convert a previous SQLite database, use the --import-murmurdb flag.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if config.StringValue("sslDHParams") != "" {
|
||||||
|
log.Println("* Go does not implement DHE modes in TLS, so the configured dhparams are ignored.")
|
||||||
|
}
|
||||||
|
if config.StringValue("ice") != "" {
|
||||||
|
log.Println("* Grumble does not support ZeroC ICE.")
|
||||||
|
}
|
||||||
|
if config.StringValue("grpc") != "" {
|
||||||
|
log.Println("* Grumble does not yet support gRPC (see issue #23 on github).")
|
||||||
|
}
|
||||||
|
|
||||||
// Open the blobstore. If the directory doesn't
|
// Open the blobstore. If the directory doesn't
|
||||||
// already exist, create the directory and open
|
// already exist, create the directory and open
|
||||||
// the blobstore.
|
// the blobstore.
|
||||||
|
|
@ -63,11 +130,9 @@ func main() {
|
||||||
|
|
||||||
// Check whether we should regenerate the default global keypair
|
// Check whether we should regenerate the default global keypair
|
||||||
// and corresponding certificate.
|
// and corresponding certificate.
|
||||||
// These are used as the default certificate of all virtual servers
|
// These are used as the default certificate of all virtual servers.
|
||||||
// and the SSH admin console, but can be overridden using the "key"
|
certFn := config.PathValue("sslCert", Args.DataDir)
|
||||||
// and "cert" arguments to Grumble.
|
keyFn := config.PathValue("sslKey", Args.DataDir)
|
||||||
certFn := filepath.Join(Args.DataDir, "cert.pem")
|
|
||||||
keyFn := filepath.Join(Args.DataDir, "key.pem")
|
|
||||||
shouldRegen := false
|
shouldRegen := false
|
||||||
if Args.RegenKeys {
|
if Args.RegenKeys {
|
||||||
shouldRegen = true
|
shouldRegen = true
|
||||||
|
|
@ -164,10 +229,10 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Unable to read file from data directory: %v", err.Error())
|
log.Fatalf("Unable to read file from data directory: %v", err.Error())
|
||||||
}
|
}
|
||||||
// The data dir file descriptor.
|
// The servers dir file descriptor.
|
||||||
err = serversDir.Close()
|
err = serversDir.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Unable to close data directory: %v", err.Error())
|
log.Fatalf("Unable to close servers directory: %v", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -181,6 +246,18 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Unable to load server: %v", err.Error())
|
log.Fatalf("Unable to load server: %v", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if SuperUser password should be updated.
|
||||||
|
if Args.ServerId == 0 || Args.ServerId == s.Id {
|
||||||
|
if Args.DisablePass {
|
||||||
|
s.cfg.Reset("SuperUserPassword")
|
||||||
|
log.Printf("Disabled SuperUser for server %v", name)
|
||||||
|
} else if Args.SuperUserPW != "" {
|
||||||
|
s.SetSuperUserPassword(Args.SuperUserPW)
|
||||||
|
log.Printf("Set SuperUser password for server %v", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = s.FreezeToFile()
|
err = s.FreezeToFile()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Unable to freeze server to disk: %v", err.Error())
|
log.Fatalf("Unable to freeze server to disk: %v", err.Error())
|
||||||
|
|
@ -189,9 +266,17 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If SuperUser password flags were passed, the servers should not start.
|
||||||
|
if Args.SuperUserPW != "" || Args.DisablePass {
|
||||||
|
if len(servers) == 0 {
|
||||||
|
log.Fatalf("No servers found to set password for")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// If no servers were found, create the default virtual server.
|
// If no servers were found, create the default virtual server.
|
||||||
if len(servers) == 0 {
|
if len(servers) == 0 {
|
||||||
s, err := NewServer(1)
|
s, err := NewServer(1, configFile.ServerConfig(1, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Couldn't start server: %s", err.Error())
|
log.Fatalf("Couldn't start server: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -593,7 +593,7 @@ func (server *Server) handleUserStateMessage(client *Client, msg *Message) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
maxChannelUsers := server.cfg.IntValue("MaxChannelUsers")
|
maxChannelUsers := server.cfg.IntValue("usersperchannel")
|
||||||
if maxChannelUsers != 0 && len(dstChan.clients) >= maxChannelUsers {
|
if maxChannelUsers != 0 && len(dstChan.clients) >= maxChannelUsers {
|
||||||
client.sendPermissionDeniedFallback(mumbleproto.PermissionDenied_ChannelFull,
|
client.sendPermissionDeniedFallback(mumbleproto.PermissionDenied_ChannelFull,
|
||||||
0x010201, "Channel is full")
|
0x010201, "Channel is full")
|
||||||
|
|
@ -653,7 +653,7 @@ func (server *Server) handleUserStateMessage(client *Client, msg *Message) {
|
||||||
|
|
||||||
// Texture change
|
// Texture change
|
||||||
if userstate.Texture != nil {
|
if userstate.Texture != nil {
|
||||||
maximg := server.cfg.IntValue("MaxImageMessageLength")
|
maximg := server.cfg.IntValue("imagemessagelength")
|
||||||
if maximg > 0 && len(userstate.Texture) > maximg {
|
if maximg > 0 && len(userstate.Texture) > maximg {
|
||||||
client.sendPermissionDeniedType(mumbleproto.PermissionDenied_TextTooLong)
|
client.sendPermissionDeniedType(mumbleproto.PermissionDenied_TextTooLong)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ func MurmurImport(filename string) (err error) {
|
||||||
|
|
||||||
// Create a new Server from a Murmur SQLite database
|
// Create a new Server from a Murmur SQLite database
|
||||||
func NewServerFromSQLite(id int64, db *sql.DB) (s *Server, err error) {
|
func NewServerFromSQLite(id int64, db *sql.DB) (s *Server, err error) {
|
||||||
s, err = NewServer(id)
|
s, err = NewServer(id, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,16 +39,19 @@ const registerUrl = "https://mumble.info/register.cgi"
|
||||||
// This function is used to determine whether or not to periodically
|
// This function is used to determine whether or not to periodically
|
||||||
// contact the master server list and update this server's metadata.
|
// contact the master server list and update this server's metadata.
|
||||||
func (server *Server) IsPublic() bool {
|
func (server *Server) IsPublic() bool {
|
||||||
if len(server.cfg.StringValue("RegisterName")) == 0 {
|
if len(server.cfg.StringValue("registerName")) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(server.cfg.StringValue("RegisterHost")) == 0 {
|
if len(server.cfg.StringValue("registerHostname")) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(server.cfg.StringValue("RegisterPassword")) == 0 {
|
if len(server.cfg.StringValue("registerPassword")) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(server.cfg.StringValue("RegisterWebUrl")) == 0 {
|
if len(server.cfg.StringValue("registerUrl")) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !server.cfg.BoolValue("allowping") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|
@ -80,11 +83,11 @@ func (server *Server) RegisterPublicServer() {
|
||||||
|
|
||||||
// Render registration XML template
|
// Render registration XML template
|
||||||
reg := Register{
|
reg := Register{
|
||||||
Name: server.cfg.StringValue("RegisterName"),
|
Name: server.cfg.StringValue("registerName"),
|
||||||
Host: server.cfg.StringValue("RegisterHost"),
|
Host: server.cfg.StringValue("registerHostname"),
|
||||||
Password: server.cfg.StringValue("RegisterPassword"),
|
Password: server.cfg.StringValue("registerPassword"),
|
||||||
Url: server.cfg.StringValue("RegisterWebUrl"),
|
Url: server.cfg.StringValue("registerUrl"),
|
||||||
Location: server.cfg.StringValue("RegisterLocation"),
|
Location: server.cfg.StringValue("registerLocation"),
|
||||||
Port: server.CurrentPort(),
|
Port: server.CurrentPort(),
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Users: len(server.clients),
|
Users: len(server.clients),
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -75,8 +74,8 @@ type Server struct {
|
||||||
|
|
||||||
incoming chan *Message
|
incoming chan *Message
|
||||||
voicebroadcast chan *VoiceBroadcast
|
voicebroadcast chan *VoiceBroadcast
|
||||||
cfgUpdate chan *KeyValuePair
|
|
||||||
tempRemove chan *Channel
|
tempRemove chan *Channel
|
||||||
|
//cfgUpdate chan *KeyValuePair
|
||||||
|
|
||||||
// Signals to the server that a client has been successfully
|
// Signals to the server that a client has been successfully
|
||||||
// authenticated.
|
// authenticated.
|
||||||
|
|
@ -138,12 +137,16 @@ func (lf clientLogForwarder) Write(incoming []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate a new Murmur instance
|
// Allocate a new Murmur instance
|
||||||
func NewServer(id int64) (s *Server, err error) {
|
func NewServer(id int64, config *serverconf.Config) (s *Server, err error) {
|
||||||
s = new(Server)
|
s = new(Server)
|
||||||
|
|
||||||
s.Id = id
|
s.Id = id
|
||||||
|
|
||||||
s.cfg = serverconf.New(nil)
|
if config == nil {
|
||||||
|
s.cfg = serverconf.New(nil, nil)
|
||||||
|
} else {
|
||||||
|
s.cfg = config
|
||||||
|
}
|
||||||
|
|
||||||
s.Users = make(map[uint32]*User)
|
s.Users = make(map[uint32]*User)
|
||||||
s.UserCertMap = make(map[string]*User)
|
s.UserCertMap = make(map[string]*User)
|
||||||
|
|
@ -175,6 +178,16 @@ func (server *Server) RootChannel() *Channel {
|
||||||
return root
|
return root
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get a pointer to the default channel
|
||||||
|
func (server *Server) DefaultChannel() *Channel {
|
||||||
|
channel, exists := server.Channels[server.cfg.IntValue("defaultchannel")]
|
||||||
|
if !exists {
|
||||||
|
return server.RootChannel()
|
||||||
|
}
|
||||||
|
return channel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set password as the new SuperUser password
|
||||||
func (server *Server) setConfigPassword(key, password string) {
|
func (server *Server) setConfigPassword(key, password string) {
|
||||||
saltBytes := make([]byte, 24)
|
saltBytes := make([]byte, 24)
|
||||||
_, err := rand.Read(saltBytes)
|
_, err := rand.Read(saltBytes)
|
||||||
|
|
@ -191,10 +204,6 @@ func (server *Server) setConfigPassword(key, password string) {
|
||||||
// Could be racy, but shouldn't really matter...
|
// Could be racy, but shouldn't really matter...
|
||||||
val := "sha1$" + salt + "$" + digest
|
val := "sha1$" + salt + "$" + digest
|
||||||
server.cfg.Set(key, val)
|
server.cfg.Set(key, val)
|
||||||
|
|
||||||
if server.cfgUpdate != nil {
|
|
||||||
server.cfgUpdate <- &KeyValuePair{Key: key, Value: val}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSuperUserPassword sets password as the new SuperUser password
|
// SetSuperUserPassword sets password as the new SuperUser password
|
||||||
|
|
@ -272,7 +281,14 @@ func (server *Server) handleIncomingClient(conn net.Conn) (err error) {
|
||||||
client.lf = &clientLogForwarder{client, server.Logger}
|
client.lf = &clientLogForwarder{client, server.Logger}
|
||||||
client.Logger = log.New(client.lf, "", 0)
|
client.Logger = log.New(client.lf, "", 0)
|
||||||
|
|
||||||
client.session = server.pool.Get()
|
client.session, err = server.pool.Get()
|
||||||
|
if err != nil {
|
||||||
|
// Server is full. Murmur just closes the connection here anyway,
|
||||||
|
// so don't bother sending a Reject_ServerFull
|
||||||
|
client.Printf("Server is full, rejecting %v", conn.RemoteAddr())
|
||||||
|
conn.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
client.Printf("New connection: %v (%v)", conn.RemoteAddr(), client.Session())
|
client.Printf("New connection: %v (%v)", conn.RemoteAddr(), client.Session())
|
||||||
|
|
||||||
client.tcpaddr = addr.(*net.TCPAddr)
|
client.tcpaddr = addr.(*net.TCPAddr)
|
||||||
|
|
@ -436,14 +452,6 @@ func (server *Server) handlerLoop() {
|
||||||
case client := <-server.clientAuthenticated:
|
case client := <-server.clientAuthenticated:
|
||||||
server.finishAuthenticate(client)
|
server.finishAuthenticate(client)
|
||||||
|
|
||||||
// Disk freeze config update
|
|
||||||
case kvp := <-server.cfgUpdate:
|
|
||||||
if !kvp.Reset {
|
|
||||||
server.UpdateConfig(kvp.Key, kvp.Value)
|
|
||||||
} else {
|
|
||||||
server.ResetConfig(kvp.Key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server registration update
|
// Server registration update
|
||||||
// Tick every hour + a minute offset based on the server id.
|
// Tick every hour + a minute offset based on the server id.
|
||||||
case <-regtick:
|
case <-regtick:
|
||||||
|
|
@ -537,14 +545,15 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
|
||||||
client.user = user
|
client.user = user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// Otherwise, the user is unregistered. If there is a server-wide password, they now need it.
|
||||||
if client.user == nil && server.hasServerPassword() {
|
if client.user == nil && server.hasServerPassword() {
|
||||||
if auth.Password == nil || !server.CheckServerPassword(*auth.Password) {
|
if auth.Password == nil || !server.CheckServerPassword(*auth.Password) {
|
||||||
client.RejectAuth(mumbleproto.Reject_WrongServerPW, "Invalid server password")
|
client.RejectAuth(mumbleproto.Reject_WrongServerPW, "Invalid server password")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Setup the cryptstate for the client.
|
// Setup the cryptstate for the client.
|
||||||
err = client.crypt.GenerateKey(client.CryptoMode)
|
err = client.crypt.GenerateKey(client.CryptoMode)
|
||||||
|
|
@ -627,8 +636,8 @@ func (server *Server) finishAuthenticate(client *Client) {
|
||||||
server.hclients[host] = append(server.hclients[host], client)
|
server.hclients[host] = append(server.hclients[host], client)
|
||||||
server.hmutex.Unlock()
|
server.hmutex.Unlock()
|
||||||
|
|
||||||
channel := server.RootChannel()
|
channel := server.DefaultChannel()
|
||||||
if client.IsRegistered() {
|
if server.cfg.BoolValue("rememberchannel") && client.IsRegistered() {
|
||||||
lastChannel := server.Channels[client.user.LastChannelId]
|
lastChannel := server.Channels[client.user.LastChannelId]
|
||||||
if lastChannel != nil {
|
if lastChannel != nil {
|
||||||
channel = lastChannel
|
channel = lastChannel
|
||||||
|
|
@ -684,8 +693,8 @@ func (server *Server) finishAuthenticate(client *Client) {
|
||||||
|
|
||||||
sync := &mumbleproto.ServerSync{}
|
sync := &mumbleproto.ServerSync{}
|
||||||
sync.Session = proto.Uint32(client.Session())
|
sync.Session = proto.Uint32(client.Session())
|
||||||
sync.MaxBandwidth = proto.Uint32(server.cfg.Uint32Value("MaxBandwidth"))
|
sync.MaxBandwidth = proto.Uint32(server.cfg.Uint32Value("bandwidth"))
|
||||||
sync.WelcomeText = proto.String(server.cfg.StringValue("WelcomeText"))
|
sync.WelcomeText = proto.String(server.cfg.StringValue("welcometext"))
|
||||||
if client.IsSuperUser() {
|
if client.IsSuperUser() {
|
||||||
sync.Permissions = proto.Uint64(uint64(acl.AllPermissions))
|
sync.Permissions = proto.Uint64(uint64(acl.AllPermissions))
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -702,9 +711,9 @@ func (server *Server) finishAuthenticate(client *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err := client.sendMessage(&mumbleproto.ServerConfig{
|
err := client.sendMessage(&mumbleproto.ServerConfig{
|
||||||
AllowHtml: proto.Bool(server.cfg.BoolValue("AllowHTML")),
|
AllowHtml: proto.Bool(server.cfg.BoolValue("allowhtml")),
|
||||||
MessageLength: proto.Uint32(server.cfg.Uint32Value("MaxTextMessageLength")),
|
MessageLength: proto.Uint32(server.cfg.Uint32Value("textmessagelength")),
|
||||||
ImageMessageLength: proto.Uint32(server.cfg.Uint32Value("MaxImageMessageLength")),
|
ImageMessageLength: proto.Uint32(server.cfg.Uint32Value("imagemessagelength")),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Panicf("%v", err)
|
client.Panicf("%v", err)
|
||||||
|
|
@ -1001,6 +1010,9 @@ func (server *Server) udpListenLoop() {
|
||||||
|
|
||||||
// Length 12 is for ping datagrams from the ConnectDialog.
|
// Length 12 is for ping datagrams from the ConnectDialog.
|
||||||
if nread == 12 {
|
if nread == 12 {
|
||||||
|
if !server.cfg.BoolValue("allowping") {
|
||||||
|
return
|
||||||
|
}
|
||||||
readbuf := bytes.NewBuffer(buf)
|
readbuf := bytes.NewBuffer(buf)
|
||||||
var (
|
var (
|
||||||
tmp32 uint32
|
tmp32 uint32
|
||||||
|
|
@ -1010,11 +1022,11 @@ func (server *Server) udpListenLoop() {
|
||||||
_ = binary.Read(readbuf, binary.BigEndian, &rand)
|
_ = binary.Read(readbuf, binary.BigEndian, &rand)
|
||||||
|
|
||||||
buffer := bytes.NewBuffer(make([]byte, 0, 24))
|
buffer := bytes.NewBuffer(make([]byte, 0, 24))
|
||||||
_ = binary.Write(buffer, binary.BigEndian, uint32((1<<16)|(2<<8)|2))
|
_ = binary.Write(buffer, binary.BigEndian, uint32((1<<16)|(2<<8)|4))
|
||||||
_ = binary.Write(buffer, binary.BigEndian, rand)
|
_ = binary.Write(buffer, binary.BigEndian, rand)
|
||||||
_ = binary.Write(buffer, binary.BigEndian, uint32(len(server.clients)))
|
_ = binary.Write(buffer, binary.BigEndian, uint32(len(server.clients)))
|
||||||
_ = binary.Write(buffer, binary.BigEndian, server.cfg.Uint32Value("MaxUsers"))
|
_ = binary.Write(buffer, binary.BigEndian, server.cfg.Uint32Value("users"))
|
||||||
_ = binary.Write(buffer, binary.BigEndian, server.cfg.Uint32Value("MaxBandwidth"))
|
_ = binary.Write(buffer, binary.BigEndian, server.cfg.Uint32Value("bandwidth"))
|
||||||
|
|
||||||
err = server.SendUDP(buffer.Bytes(), udpaddr)
|
err = server.SendUDP(buffer.Bytes(), udpaddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1292,9 +1304,9 @@ func (server *Server) IsCertHashBanned(hash string) bool {
|
||||||
// Filter incoming text according to the server's current rules.
|
// Filter incoming text according to the server's current rules.
|
||||||
func (server *Server) FilterText(text string) (filtered string, err error) {
|
func (server *Server) FilterText(text string) (filtered string, err error) {
|
||||||
options := &htmlfilter.Options{
|
options := &htmlfilter.Options{
|
||||||
StripHTML: !server.cfg.BoolValue("AllowHTML"),
|
StripHTML: !server.cfg.BoolValue("allowhtml"),
|
||||||
MaxTextMessageLength: server.cfg.IntValue("MaxTextMessageLength"),
|
MaxTextMessageLength: server.cfg.IntValue("textmessagelength"),
|
||||||
MaxImageMessageLength: server.cfg.IntValue("MaxImageMessageLength"),
|
MaxImageMessageLength: server.cfg.IntValue("imagemessagelength"),
|
||||||
}
|
}
|
||||||
return htmlfilter.Filter(text, options)
|
return htmlfilter.Filter(text, options)
|
||||||
}
|
}
|
||||||
|
|
@ -1348,7 +1360,7 @@ func isTimeout(err error) bool {
|
||||||
|
|
||||||
// Initialize the per-launch data
|
// Initialize the per-launch data
|
||||||
func (server *Server) initPerLaunchData() {
|
func (server *Server) initPerLaunchData() {
|
||||||
server.pool = sessionpool.New()
|
server.pool = sessionpool.New(server.cfg.Uint32Value("users"))
|
||||||
server.clients = make(map[uint32]*Client)
|
server.clients = make(map[uint32]*Client)
|
||||||
server.hclients = make(map[string][]*Client)
|
server.hclients = make(map[string][]*Client)
|
||||||
server.hpclients = make(map[string]*Client)
|
server.hpclients = make(map[string]*Client)
|
||||||
|
|
@ -1356,7 +1368,6 @@ func (server *Server) initPerLaunchData() {
|
||||||
server.bye = make(chan bool)
|
server.bye = make(chan bool)
|
||||||
server.incoming = make(chan *Message)
|
server.incoming = make(chan *Message)
|
||||||
server.voicebroadcast = make(chan *VoiceBroadcast)
|
server.voicebroadcast = make(chan *VoiceBroadcast)
|
||||||
server.cfgUpdate = make(chan *KeyValuePair)
|
|
||||||
server.tempRemove = make(chan *Channel, 1)
|
server.tempRemove = make(chan *Channel, 1)
|
||||||
server.clientAuthenticated = make(chan *Client)
|
server.clientAuthenticated = make(chan *Client)
|
||||||
}
|
}
|
||||||
|
|
@ -1371,7 +1382,6 @@ func (server *Server) cleanPerLaunchData() {
|
||||||
server.bye = nil
|
server.bye = nil
|
||||||
server.incoming = nil
|
server.incoming = nil
|
||||||
server.voicebroadcast = nil
|
server.voicebroadcast = nil
|
||||||
server.cfgUpdate = nil
|
|
||||||
server.tempRemove = nil
|
server.tempRemove = nil
|
||||||
server.clientAuthenticated = nil
|
server.clientAuthenticated = nil
|
||||||
}
|
}
|
||||||
|
|
@ -1379,7 +1389,7 @@ func (server *Server) cleanPerLaunchData() {
|
||||||
// Port returns the port the native server will listen on when it is
|
// Port returns the port the native server will listen on when it is
|
||||||
// started.
|
// started.
|
||||||
func (server *Server) Port() int {
|
func (server *Server) Port() int {
|
||||||
port := server.cfg.IntValue("Port")
|
port := server.cfg.IntValue("port")
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
return DefaultPort + int(server.Id) - 1
|
return DefaultPort + int(server.Id) - 1
|
||||||
}
|
}
|
||||||
|
|
@ -1389,13 +1399,13 @@ func (server *Server) Port() int {
|
||||||
// ListenWebPort returns true if we should listen to the
|
// ListenWebPort returns true if we should listen to the
|
||||||
// web port, otherwise false
|
// web port, otherwise false
|
||||||
func (server *Server) ListenWebPort() bool {
|
func (server *Server) ListenWebPort() bool {
|
||||||
return !server.cfg.BoolValue("NoWebServer")
|
return !server.cfg.BoolValue("nowebserver")
|
||||||
}
|
}
|
||||||
|
|
||||||
// WebPort returns the port the web server will listen on when it is
|
// WebPort returns the port the web server will listen on when it is
|
||||||
// started.
|
// started.
|
||||||
func (server *Server) WebPort() int {
|
func (server *Server) WebPort() int {
|
||||||
port := server.cfg.IntValue("WebPort")
|
port := server.cfg.IntValue("webport")
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
return DefaultWebPort + int(server.Id) - 1
|
return DefaultWebPort + int(server.Id) - 1
|
||||||
}
|
}
|
||||||
|
|
@ -1417,7 +1427,7 @@ func (server *Server) CurrentPort() int {
|
||||||
// it is started. This must be an IP address, either IPv4
|
// it is started. This must be an IP address, either IPv4
|
||||||
// or IPv6.
|
// or IPv6.
|
||||||
func (server *Server) HostAddress() string {
|
func (server *Server) HostAddress() string {
|
||||||
host := server.cfg.StringValue("Address")
|
host := server.cfg.StringValue("host")
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return "0.0.0.0"
|
return "0.0.0.0"
|
||||||
}
|
}
|
||||||
|
|
@ -1460,8 +1470,8 @@ func (server *Server) Start() (err error) {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Wrap a TLS listener around the TCP connection
|
// Wrap a TLS listener around the TCP connection
|
||||||
certFn := filepath.Join(Args.DataDir, "cert.pem")
|
certFn := server.cfg.PathValue("sslCert", Args.DataDir)
|
||||||
keyFn := filepath.Join(Args.DataDir, "key.pem")
|
keyFn := server.cfg.PathValue("sslKey", Args.DataDir)
|
||||||
cert, err := tls.LoadX509KeyPair(certFn, keyFn)
|
cert, err := tls.LoadX509KeyPair(certFn, keyFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -1470,6 +1480,15 @@ func (server *Server) Start() (err error) {
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
ClientAuth: tls.RequestClientCert,
|
ClientAuth: tls.RequestClientCert,
|
||||||
}
|
}
|
||||||
|
ciphersstr := server.cfg.StringValue("sslCiphers")
|
||||||
|
if ciphersstr != "" {
|
||||||
|
var invalid []string
|
||||||
|
server.tlscfg.CipherSuites, invalid = serverconf.ParseCipherlist(ciphersstr)
|
||||||
|
for _, cipher := range invalid {
|
||||||
|
log.Printf("Ignoring invalid or unsupported cipher \"%v\"", cipher)
|
||||||
|
}
|
||||||
|
server.tlscfg.PreferServerCipherSuites = true
|
||||||
|
}
|
||||||
server.tlsl = tls.NewListener(server.tcpl, server.tlscfg)
|
server.tlsl = tls.NewListener(server.tcpl, server.tlscfg)
|
||||||
|
|
||||||
if shouldListenWeb {
|
if shouldListenWeb {
|
||||||
|
|
|
||||||
1
go.mod
1
go.mod
|
|
@ -6,4 +6,5 @@ require (
|
||||||
github.com/golang/protobuf v1.3.5
|
github.com/golang/protobuf v1.3.5
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
golang.org/x/crypto v0.0.0-20200406173513-056763e48d71
|
golang.org/x/crypto v0.0.0-20200406173513-056763e48d71
|
||||||
|
gopkg.in/ini.v1 v1.55.0
|
||||||
)
|
)
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -10,3 +10,5 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
gopkg.in/ini.v1 v1.55.0 h1:E8yzL5unfpW3M6fz/eB7Cb5MQAYSZ7GKo4Qth+N2sgQ=
|
||||||
|
gopkg.in/ini.v1 v1.55.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
|
|
|
||||||
77
pkg/serverconf/cipherlist.go
Normal file
77
pkg/serverconf/cipherlist.go
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
package serverconf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cipherLookup = map[string]uint16{
|
||||||
|
// RFC
|
||||||
|
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||||
|
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
|
// These are the actual names per RFC 7905
|
||||||
|
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
|
|
||||||
|
// OpenSSL
|
||||||
|
"RC4-SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||||
|
"DES-CBC3-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"AES128-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"AES256-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"AES128-SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"ECDHE-ECDSA-RC4-SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||||
|
"ECDHE-ECDSA-AES128-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"ECDHE-ECDSA-AES256-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"ECDHE-RSA-RC4-SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||||
|
"ECDHE-RSA-DES-CBC3-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"ECDHE-RSA-AES128-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"ECDHE-RSA-AES256-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"ECDHE-ECDSA-AES128-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"ECDHE-RSA-AES128-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
|
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
"ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCipherlist parses a list of cipher suites separated by colons.
|
||||||
|
// It supports both RFC and OpenSSL names, but does not support OpenSSL
|
||||||
|
// cipher strings representing categories of cipher suites.
|
||||||
|
func ParseCipherlist(list string) (ciphers []uint16, invalid []string) {
|
||||||
|
strciphers := strings.Split(list, ":")
|
||||||
|
ciphers = make([]uint16, 0, len(strciphers))
|
||||||
|
invalid = make([]string, 0)
|
||||||
|
for _, v := range strciphers {
|
||||||
|
c, ok := cipherLookup[v]
|
||||||
|
if ok {
|
||||||
|
ciphers = append(ciphers, c)
|
||||||
|
} else {
|
||||||
|
invalid = append(invalid, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -5,44 +5,51 @@
|
||||||
package serverconf
|
package serverconf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultCfg = map[string]string{
|
var defaultCfg = map[string]string{
|
||||||
"MaxBandwidth": "72000",
|
"bandwidth": "72000",
|
||||||
"MaxUsers": "1000",
|
"users": "1000",
|
||||||
"MaxUsersPerChannel": "0",
|
"usersperchannel": "0",
|
||||||
"MaxTextMessageLength": "5000",
|
"textmessagelength": "5000",
|
||||||
"MaxImageMessageLength": "131072",
|
"imagemessagelength": "131072",
|
||||||
"AllowHTML": "true",
|
"allowhtml": "true",
|
||||||
"DefaultChannel": "0",
|
"defaultchannel": "0",
|
||||||
"RememberChannel": "true",
|
"rememberchannel": "true",
|
||||||
"WelcomeText": "Welcome to this server running <b>Grumble</b>.",
|
"welcometext": "Welcome to this server running <b>Grumble</b>.",
|
||||||
"SendVersion": "true",
|
"sendversion": "true",
|
||||||
|
"allowping": "true",
|
||||||
|
"logfile": "grumble.log",
|
||||||
|
"sslCert": "cert.pem",
|
||||||
|
"sslKey": "key.pem",
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
cfgMap map[string]string
|
fallbackMap map[string]string
|
||||||
|
persistentMap map[string]string
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new Config using cfgMap as the intial internal config map.
|
// New returns a new Config using persistentMap as the initial internal config map.
|
||||||
// If cfgMap is nil, ConfigWithMap will create a new config map.
|
// The map persistentMap may not be reused. If set to nil, a new map is created.
|
||||||
func New(cfgMap map[string]string) *Config {
|
// Optionally, defaults may be passed in fallbackMap. This map is only read, not written.
|
||||||
if cfgMap == nil {
|
func New(persistentMap, fallbackMap map[string]string) *Config {
|
||||||
cfgMap = make(map[string]string)
|
if persistentMap == nil {
|
||||||
|
persistentMap = make(map[string]string)
|
||||||
}
|
}
|
||||||
return &Config{cfgMap: cfgMap}
|
return &Config{persistentMap: persistentMap, fallbackMap: fallbackMap}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAll gets a copy of the Config's internal config map
|
// GetAllPersistent returns a copy of the internal persistent key-value map.
|
||||||
func (cfg *Config) GetAll() (all map[string]string) {
|
func (cfg *Config) GetAllPersistent() (all map[string]string) {
|
||||||
cfg.mutex.RLock()
|
cfg.mutex.RLock()
|
||||||
defer cfg.mutex.RUnlock()
|
defer cfg.mutex.RUnlock()
|
||||||
|
|
||||||
all = make(map[string]string)
|
all = make(map[string]string)
|
||||||
for k, v := range cfg.cfgMap {
|
for k, v := range cfg.persistentMap {
|
||||||
all[k] = v
|
all[k] = v
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
@ -52,14 +59,14 @@ func (cfg *Config) GetAll() (all map[string]string) {
|
||||||
func (cfg *Config) Set(key string, value string) {
|
func (cfg *Config) Set(key string, value string) {
|
||||||
cfg.mutex.Lock()
|
cfg.mutex.Lock()
|
||||||
defer cfg.mutex.Unlock()
|
defer cfg.mutex.Unlock()
|
||||||
cfg.cfgMap[key] = value
|
cfg.persistentMap[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset the value of a config key
|
// Reset the value of a config key
|
||||||
func (cfg *Config) Reset(key string) {
|
func (cfg *Config) Reset(key string) {
|
||||||
cfg.mutex.Lock()
|
cfg.mutex.Lock()
|
||||||
defer cfg.mutex.Unlock()
|
defer cfg.mutex.Unlock()
|
||||||
delete(cfg.cfgMap, key)
|
delete(cfg.persistentMap, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringValue gets the value of a specific config key encoded as a string
|
// StringValue gets the value of a specific config key encoded as a string
|
||||||
|
|
@ -67,7 +74,12 @@ func (cfg *Config) StringValue(key string) (value string) {
|
||||||
cfg.mutex.RLock()
|
cfg.mutex.RLock()
|
||||||
defer cfg.mutex.RUnlock()
|
defer cfg.mutex.RUnlock()
|
||||||
|
|
||||||
value, exists := cfg.cfgMap[key]
|
value, exists := cfg.persistentMap[key]
|
||||||
|
if exists {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists = cfg.fallbackMap[key]
|
||||||
if exists {
|
if exists {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
@ -80,7 +92,7 @@ func (cfg *Config) StringValue(key string) (value string) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// IntValue gets the value of a speific config key as an int
|
// Get the value of a specific config key as an int
|
||||||
func (cfg *Config) IntValue(key string) (intval int) {
|
func (cfg *Config) IntValue(key string) (intval int) {
|
||||||
str := cfg.StringValue(key)
|
str := cfg.StringValue(key)
|
||||||
intval, _ = strconv.Atoi(str)
|
intval, _ = strconv.Atoi(str)
|
||||||
|
|
@ -94,9 +106,19 @@ func (cfg *Config) Uint32Value(key string) (uint32val uint32) {
|
||||||
return uint32(uintval)
|
return uint32(uintval)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BoolValue gets the value fo a sepcific config key as a bool
|
// Get the value of a specific config key as a bool
|
||||||
func (cfg *Config) BoolValue(key string) (boolval bool) {
|
func (cfg *Config) BoolValue(key string) (boolval bool) {
|
||||||
str := cfg.StringValue(key)
|
str := cfg.StringValue(key)
|
||||||
boolval, _ = strconv.ParseBool(str)
|
boolval, _ = strconv.ParseBool(str)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the value of a specific config key as a path,
|
||||||
|
// joined with the path in rel if not absolute.
|
||||||
|
func (cfg *Config) PathValue(key string, rel string) (path string) {
|
||||||
|
str := cfg.StringValue(key)
|
||||||
|
if filepath.IsAbs(str) {
|
||||||
|
return filepath.Clean(str)
|
||||||
|
}
|
||||||
|
return filepath.Join(rel, str)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIntValue(t *testing.T) {
|
func TestIntValue(t *testing.T) {
|
||||||
cfg := New(nil)
|
cfg := New(nil, nil)
|
||||||
cfg.Set("Test", "13")
|
cfg.Set("Test", "13")
|
||||||
if cfg.IntValue("Test") != 13 {
|
if cfg.IntValue("Test") != 13 {
|
||||||
t.Errorf("Expected 13")
|
t.Errorf("Expected 13")
|
||||||
|
|
@ -17,7 +17,7 @@ func TestIntValue(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFloatAsInt(t *testing.T) {
|
func TestFloatAsInt(t *testing.T) {
|
||||||
cfg := New(nil)
|
cfg := New(nil, nil)
|
||||||
cfg.Set("Test", "13.4")
|
cfg.Set("Test", "13.4")
|
||||||
if cfg.IntValue("Test") != 0 {
|
if cfg.IntValue("Test") != 0 {
|
||||||
t.Errorf("Expected 0")
|
t.Errorf("Expected 0")
|
||||||
|
|
@ -25,14 +25,14 @@ func TestFloatAsInt(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultValue(t *testing.T) {
|
func TestDefaultValue(t *testing.T) {
|
||||||
cfg := New(nil)
|
cfg := New(nil, nil)
|
||||||
if cfg.IntValue("MaxBandwidth") != 72000 {
|
if cfg.IntValue("bandwidth") != 72000 {
|
||||||
t.Errorf("Expected 72000")
|
t.Errorf("Expected 72000")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBoolValue(t *testing.T) {
|
func TestBoolValue(t *testing.T) {
|
||||||
cfg := New(nil)
|
cfg := New(nil, nil)
|
||||||
cfg.Set("DoStuffOnStartup", "true")
|
cfg.Set("DoStuffOnStartup", "true")
|
||||||
if cfg.BoolValue("DoStuffOnStartup") != true {
|
if cfg.BoolValue("DoStuffOnStartup") != true {
|
||||||
t.Errorf("Expected true")
|
t.Errorf("Expected true")
|
||||||
|
|
|
||||||
54
pkg/serverconf/file.go
Normal file
54
pkg/serverconf/file.go
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
package serverconf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cfg interface {
|
||||||
|
// GlobalMap returns a copy of the top-level (global) configuration map.
|
||||||
|
GlobalMap() map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigFile struct {
|
||||||
|
cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfigFile(path string) (*ConfigFile, error) {
|
||||||
|
var f cfg
|
||||||
|
f, err := newinicfg(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ConfigFile{f}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GlobalConfig returns a new *serverconf.Config representing the top-level
|
||||||
|
// (global) configuration.
|
||||||
|
func (c *ConfigFile) GlobalConfig() *Config {
|
||||||
|
return New(nil, c.GlobalMap())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig returns a new *serverconf.Config with the fallback representing
|
||||||
|
// the global configuration with server-specific values incremented by id.
|
||||||
|
// Optionally a persistent map which has priority may be passed. This map
|
||||||
|
// is consumed and cannot be reused.
|
||||||
|
func (c *ConfigFile) ServerConfig(id int64, persistentMap map[string]string) *Config {
|
||||||
|
m := c.GlobalMap()
|
||||||
|
|
||||||
|
// Some server specific values from the global config must be offset.
|
||||||
|
// These are read differently by the server as well.
|
||||||
|
if v, ok := m["port"]; ok {
|
||||||
|
i, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
m["port"] = strconv.FormatInt(i+id-1, 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := m["webport"]; ok {
|
||||||
|
i, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
m["webport"] = strconv.FormatInt(i+id-1, 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return New(persistentMap, m)
|
||||||
|
}
|
||||||
90
pkg/serverconf/file_ini.go
Normal file
90
pkg/serverconf/file_ini.go
Normal file
|
|
@ -0,0 +1,90 @@
|
||||||
|
package serverconf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gopkg.in/ini.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type inicfg struct {
|
||||||
|
file *ini.File
|
||||||
|
}
|
||||||
|
|
||||||
|
func newinicfg(path string) (*inicfg, error) {
|
||||||
|
file, err := ini.LoadSources(ini.LoadOptions{AllowBooleanKeys: true, UnescapeValueDoubleQuotes: true}, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
file.BlockMode = false // read only, avoid locking
|
||||||
|
return &inicfg{file}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *inicfg) GlobalMap() map[string]string {
|
||||||
|
return f.file.Section("").KeysHash()
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultConfigFile = `# Grumble configuration file.
|
||||||
|
#
|
||||||
|
# The commented out settings represent the defaults.
|
||||||
|
# Options here may be overridden by virtual server specific configuration.
|
||||||
|
# Make sure to enclose values containing # or ; in double quotes or backticks.
|
||||||
|
|
||||||
|
# Address to bind the listeners to.
|
||||||
|
#host = 0.0.0.0
|
||||||
|
|
||||||
|
# port is the port to bind the native Mumble protocol to.
|
||||||
|
# webport is the port to bind the WebSocket Mumble protocol to.
|
||||||
|
# They are incremented for each virtual server (if set globally).
|
||||||
|
#port = 64738
|
||||||
|
#webport = 443
|
||||||
|
|
||||||
|
# Whether to disable web server.
|
||||||
|
#nowebserver
|
||||||
|
|
||||||
|
# "Message of the day" HTML string sent to connecting clients.
|
||||||
|
#welcometext = "Welcome to this server running <b>Grumble</b>."
|
||||||
|
|
||||||
|
# Password to join the server.
|
||||||
|
#serverpassword =
|
||||||
|
|
||||||
|
# Maximum bandwidth (in bits per second) per client for voice.
|
||||||
|
# Grumble does not yet enforce this limit, but some clients nicely follow it.
|
||||||
|
#bandwidth = 72000
|
||||||
|
|
||||||
|
# Maximum number of concurrent clients.
|
||||||
|
#users = 1000
|
||||||
|
#usersperchannel = 0
|
||||||
|
|
||||||
|
#textmessagelength = 5000
|
||||||
|
#imagemessagelength = 131072
|
||||||
|
#allowhtml
|
||||||
|
|
||||||
|
# The default channel is the channel (by ID) new users join.
|
||||||
|
# The root channel (ID = 0) is the default.
|
||||||
|
#defaultchannel = 0
|
||||||
|
|
||||||
|
# Whether users will rejoin the last channel they were in.
|
||||||
|
#rememberchannel
|
||||||
|
|
||||||
|
# Whether to include server OS info in ping response.
|
||||||
|
#sendversion
|
||||||
|
|
||||||
|
# Whether to respond to pings from the Connect dialog.
|
||||||
|
#allowping
|
||||||
|
|
||||||
|
# Path to the log file (relative to the data directory).
|
||||||
|
#logfile = grumble.log
|
||||||
|
|
||||||
|
# Path to TLS certificate and key (relative to the data directory).
|
||||||
|
# The certificate needs to have the entire chain concatenated to be valid.
|
||||||
|
# If these paths do not exist, Grumble will autogenerate a certificate.
|
||||||
|
#sslCert = cert.pem
|
||||||
|
#sslKey = key.pem
|
||||||
|
|
||||||
|
# Options for public server registration.
|
||||||
|
# All of these have to be set to make the server public.
|
||||||
|
# registerName additionally sets the name of the root channel.
|
||||||
|
# registerPassword is a simple, arbitrary secret to guard your registration. Don't lose it.
|
||||||
|
#registerName =
|
||||||
|
#registerHostname =
|
||||||
|
#registerPassword =
|
||||||
|
#registerUrl =
|
||||||
|
`
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
package sessionpool
|
package sessionpool
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
@ -17,11 +18,17 @@ type SessionPool struct {
|
||||||
used map[uint32]bool
|
used map[uint32]bool
|
||||||
unused []uint32
|
unused []uint32
|
||||||
cur uint32
|
cur uint32
|
||||||
|
max uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new SessionPool container.
|
// Create a new SessionPool container.
|
||||||
func New() (pool *SessionPool) {
|
func New(max uint32) (pool *SessionPool) {
|
||||||
pool = new(SessionPool)
|
pool = new(SessionPool)
|
||||||
|
if max == 0 {
|
||||||
|
pool.max = math.MaxUint32
|
||||||
|
} else {
|
||||||
|
pool.max = max
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -41,7 +48,7 @@ func (pool *SessionPool) EnableUseTracking() {
|
||||||
|
|
||||||
// Get a new session ID from the SessionPool.
|
// Get a new session ID from the SessionPool.
|
||||||
// Must be reclaimed using Reclaim() when done using it.
|
// Must be reclaimed using Reclaim() when done using it.
|
||||||
func (pool *SessionPool) Get() (id uint32) {
|
func (pool *SessionPool) Get() (id uint32, err error) {
|
||||||
pool.mutex.Lock()
|
pool.mutex.Lock()
|
||||||
defer pool.mutex.Unlock()
|
defer pool.mutex.Unlock()
|
||||||
|
|
||||||
|
|
@ -60,11 +67,12 @@ func (pool *SessionPool) Get() (id uint32) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for depletion. If cur is MaxUint32,
|
// Check for depletion. If cur is max,
|
||||||
// there aren't any session IDs left, since the
|
// there aren't any session IDs left, since the
|
||||||
// increment below would overflow us back to 0.
|
// increment below would return an out of range ID.
|
||||||
if pool.cur == math.MaxUint32 {
|
if pool.cur == pool.max {
|
||||||
panic("SessionPool depleted")
|
err = errors.New("depleted session pool")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the next session id and return it.
|
// Increment the next session id and return it.
|
||||||
|
|
|
||||||
|
|
@ -6,35 +6,41 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReclaim(t *testing.T) {
|
func TestReclaim(t *testing.T) {
|
||||||
pool := New()
|
pool := New(2)
|
||||||
id := pool.Get()
|
id, err := pool.Get()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error: %v", err)
|
||||||
|
}
|
||||||
if id != 1 {
|
if id != 1 {
|
||||||
t.Errorf("Got %v, expected 1 (first time)", id)
|
t.Errorf("Got %v, expected 1 (first time)", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pool.Reclaim(1)
|
pool.Reclaim(1)
|
||||||
|
|
||||||
id = pool.Get()
|
id, err = pool.Get()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error: %v", err)
|
||||||
|
}
|
||||||
if id != 1 {
|
if id != 1 {
|
||||||
t.Errorf("Got %v, expected 1 (second time)", id)
|
t.Errorf("Got %v, expected 1 (second time)", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
id = pool.Get()
|
id, err = pool.Get()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error: %v", err)
|
||||||
|
}
|
||||||
if id != 2 {
|
if id != 2 {
|
||||||
t.Errorf("Got %v, expected 2", id)
|
t.Errorf("Got %v, expected 2", id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDepletion(t *testing.T) {
|
func TestDepletion(t *testing.T) {
|
||||||
defer func() {
|
pool := New(0)
|
||||||
r := recover()
|
|
||||||
if r != "SessionPool depleted" {
|
|
||||||
t.Errorf("Expected depletion panic")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
pool := New()
|
|
||||||
pool.cur = math.MaxUint32
|
pool.cur = math.MaxUint32
|
||||||
pool.Get()
|
_, err := pool.Get()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected depletion error")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUseTracking(t *testing.T) {
|
func TestUseTracking(t *testing.T) {
|
||||||
|
|
@ -45,7 +51,7 @@ func TestUseTracking(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
pool := New()
|
pool := New(0)
|
||||||
pool.EnableUseTracking()
|
pool.EnableUseTracking()
|
||||||
pool.Reclaim(42)
|
pool.Reclaim(42)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue