Remove global command line arguments variable, and only define command line flags from the main function.

This commit is contained in:
Ola Bini 2020-03-26 18:40:23 +00:00
parent a6dc45193a
commit ccdbb58cb6
No known key found for this signature in database
GPG key ID: 6786A150F6A2B28F
8 changed files with 53 additions and 45 deletions

View file

@ -85,16 +85,20 @@ func Usage() {
} }
} }
var Args args func readCommandlineArguments() args {
var a args
func init() {
flag.Usage = Usage flag.Usage = Usage
flag.BoolVar(&Args.ShowHelp, "help", false, "") flag.BoolVar(&a.ShowHelp, "help", false, "")
flag.StringVar(&Args.DataDir, "datadir", defaultDataDir(), "") flag.StringVar(&a.DataDir, "datadir", defaultDataDir(), "")
flag.StringVar(&Args.LogPath, "log", defaultLogPath(), "") flag.StringVar(&a.LogPath, "log", defaultLogPath(), "")
flag.BoolVar(&Args.RegenKeys, "regen-keys", false, "") flag.BoolVar(&a.RegenKeys, "regen-keys", false, "")
flag.StringVar(&Args.SQLiteDB, "import-murmurdb", "", "") flag.StringVar(&a.SQLiteDB, "import-murmurdb", "", "")
flag.BoolVar(&Args.CleanUp, "cleanup", false, "") flag.BoolVar(&a.CleanUp, "cleanup", false, "")
flag.Parse()
return a
} }

View file

@ -52,7 +52,7 @@ func (server *Server) openFreezeLog() error {
server.freezelog = nil server.freezelog = nil
} }
logfn := filepath.Join(Args.DataDir, "servers", strconv.FormatInt(server.Id, 10), "log.fz") logfn := filepath.Join(server.DataDir, "servers", strconv.FormatInt(server.Id, 10), "log.fz")
err := os.Remove(logfn) err := os.Remove(logfn)
if os.IsNotExist(err) { if os.IsNotExist(err) {
// fallthrough // fallthrough
@ -374,13 +374,13 @@ func FreezeGroup(group acl.Group) (*freezer.Group, error) {
// Once both the full server and the log file has been merged together // Once both the full server and the log file has been merged together
// in memory, a new full seralized server will be written and synced to // in memory, a new full seralized server will be written and synced to
// disk, and the existing log file will be removed. // disk, and the existing log file will be removed.
func NewServerFromFrozen(name string) (s *Server, err error) { func NewServerFromFrozen(name string, dataDir string) (s *Server, err error) {
id, err := strconv.ParseInt(name, 10, 64) id, err := strconv.ParseInt(name, 10, 64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
path := filepath.Join(Args.DataDir, "servers", name) path := filepath.Join(dataDir, "servers", name)
mainFile := filepath.Join(path, "main.fz") mainFile := filepath.Join(path, "main.fz")
backupFile := filepath.Join(path, "backup.fz") backupFile := filepath.Join(path, "backup.fz")
logFn := filepath.Join(path, "log.fz") logFn := filepath.Join(path, "log.fz")
@ -424,6 +424,7 @@ func NewServerFromFrozen(name string) (s *Server, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.DataDir = dataDir
s.cfg = serverconf.New(cfgMap) s.cfg = serverconf.New(cfgMap)
// Unfreeze the server's frozen bans. // Unfreeze the server's frozen bans.

View file

@ -30,7 +30,7 @@ func (server *Server) freezeToFile() (err error) {
if err != nil { if err != nil {
return err return err
} }
f, err := ioutil.TempFile(filepath.Join(Args.DataDir, "servers", strconv.FormatInt(server.Id, 10)), ".main.fz_") f, err := ioutil.TempFile(filepath.Join(server.DataDir, "servers", strconv.FormatInt(server.Id, 10)), ".main.fz_")
if err != nil { if err != nil {
return err return err
} }
@ -50,7 +50,7 @@ func (server *Server) freezeToFile() (err error) {
if err != nil { if err != nil {
return err return err
} }
err = os.Rename(f.Name(), filepath.Join(Args.DataDir, "servers", strconv.FormatInt(server.Id, 10), "main.fz")) err = os.Rename(f.Name(), filepath.Join(server.DataDir, "servers", strconv.FormatInt(server.Id, 10), "main.fz"))
if err != nil { if err != nil {
return err return err
} }

View file

@ -29,7 +29,7 @@ func (server *Server) freezeToFile() (err error) {
if err != nil { if err != nil {
return err return err
} }
f, err := ioutil.TempFile(filepath.Join(Args.DataDir, "servers", strconv.FormatInt(server.Id, 10)), ".main.fz_") f, err := ioutil.TempFile(filepath.Join(server.DataDir, "servers", strconv.FormatInt(server.Id, 10)), ".main.fz_")
if err != nil { if err != nil {
return err return err
} }
@ -51,8 +51,8 @@ func (server *Server) freezeToFile() (err error) {
} }
src := f.Name() src := f.Name()
dst := filepath.Join(Args.DataDir, "servers", strconv.FormatInt(server.Id, 10), "main.fz") dst := filepath.Join(server.DataDir, "servers", strconv.FormatInt(server.Id, 10), "main.fz")
backup := filepath.Join(Args.DataDir, "servers", strconv.FormatInt(server.Id, 10), "backup.fz") backup := filepath.Join(server.DataDir, "servers", strconv.FormatInt(server.Id, 10), "backup.fz")
err = replacefile.ReplaceFile(dst, src, backup, replacefile.Flag(0)) err = replacefile.ReplaceFile(dst, src, backup, replacefile.Flag(0))
// If the dst file does not exist (as in, on first launch) // If the dst file does not exist (as in, on first launch)

View file

@ -20,7 +20,7 @@ import (
// Generate a 4096-bit RSA keypair and a Grumble auto-generated X509 // Generate a 4096-bit RSA keypair and a Grumble auto-generated X509
// certificate. Output PEM-encoded DER representations of the resulting // certificate. Output PEM-encoded DER representations of the resulting
// certificate and private key to certpath and keypath. // certificate and private key to certpath and keypath.
func GenerateSelfSignedCert(certpath, keypath string) (err error) { func GenerateSelfSignedCert(certpath, keypath string, dataDir string) (err error) {
now := time.Now() now := time.Now()
tmpl := &x509.Certificate{ tmpl := &x509.Certificate{
SerialNumber: big.NewInt(0), SerialNumber: big.NewInt(0),
@ -56,7 +56,7 @@ func GenerateSelfSignedCert(certpath, keypath string) (err error) {
Bytes: keybuf, Bytes: keybuf,
} }
certfn := filepath.Join(Args.DataDir, "cert.pem") certfn := filepath.Join(dataDir, "cert.pem")
file, err := os.OpenFile(certfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700) file, err := os.OpenFile(certfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700)
if err != nil { if err != nil {
return err return err
@ -67,7 +67,7 @@ func GenerateSelfSignedCert(certpath, keypath string) (err error) {
return err return err
} }
keyfn := filepath.Join(Args.DataDir, "key.pem") keyfn := filepath.Join(dataDir, "key.pem")
file, err = os.OpenFile(keyfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700) file, err = os.OpenFile(keyfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700)
if err != nil { if err != nil {
return err return err

View file

@ -5,7 +5,6 @@
package main package main
import ( import (
"flag"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -22,31 +21,31 @@ var blobStore blobstore.BlobStore
func main() { func main() {
var err error var err error
flag.Parse() cmdArgs := readCommandlineArguments()
if Args.ShowHelp == true { if cmdArgs.ShowHelp == true {
Usage() Usage()
return return
} }
// Open the data dir to check whether it exists. // Open the data dir to check whether it exists.
dataDir, err := os.Open(Args.DataDir) dataDir, err := os.Open(cmdArgs.DataDir)
if err != nil { if err != nil {
log.Fatalf("Unable to open data directory (%v): %v", Args.DataDir, err) log.Fatalf("Unable to open data directory (%v): %v", cmdArgs.DataDir, err)
return return
} }
dataDir.Close() dataDir.Close()
// Set up logging // Set up logging
err = logtarget.Target.OpenFile(Args.LogPath) err = logtarget.Target.OpenFile(cmdArgs.LogPath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Unable to open log file (%v): %v", Args.LogPath, err) fmt.Fprintf(os.Stderr, "Unable to open log file (%v): %v", cmdArgs.LogPath, err)
return return
} }
log.SetPrefix("[G] ") log.SetPrefix("[G] ")
log.SetFlags(log.LstdFlags | log.Lmicroseconds) log.SetFlags(log.LstdFlags | log.Lmicroseconds)
log.SetOutput(&logtarget.Target) log.SetOutput(&logtarget.Target)
log.Printf("Grumble") log.Printf("Grumble")
log.Printf("Using data directory: %s", Args.DataDir) log.Printf("Using data directory: %s", cmdArgs.DataDir)
// Open the blobstore. If the directory doesn't // Open the blobstore. If the directory doesn't
// already exist, create the directory and open // already exist, create the directory and open
@ -54,7 +53,7 @@ func main() {
// The Open method of the blobstore performs simple // The Open method of the blobstore performs simple
// sanity checking of content of the blob directory, // sanity checking of content of the blob directory,
// and will return an error if something's amiss. // and will return an error if something's amiss.
blobDir := filepath.Join(Args.DataDir, "blob") blobDir := filepath.Join(cmdArgs.DataDir, "blob")
err = os.Mkdir(blobDir, 0700) err = os.Mkdir(blobDir, 0700)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
log.Fatalf("Unable to create blob directory (%v): %v", blobDir, err) log.Fatalf("Unable to create blob directory (%v): %v", blobDir, err)
@ -66,10 +65,10 @@ func main() {
// These are used as the default certificate of all virtual servers // These are used as the default certificate of all virtual servers
// and the SSH admin console, but can be overridden using the "key" // and the SSH admin console, but can be overridden using the "key"
// and "cert" arguments to Grumble. // and "cert" arguments to Grumble.
certFn := filepath.Join(Args.DataDir, "cert.pem") certFn := filepath.Join(cmdArgs.DataDir, "cert.pem")
keyFn := filepath.Join(Args.DataDir, "key.pem") keyFn := filepath.Join(cmdArgs.DataDir, "key.pem")
shouldRegen := false shouldRegen := false
if Args.RegenKeys { if cmdArgs.RegenKeys {
shouldRegen = true shouldRegen = true
} else { } else {
// OK. Here's the idea: We check for the existence of the cert.pem // OK. Here's the idea: We check for the existence of the cert.pem
@ -100,7 +99,7 @@ func main() {
if shouldRegen { if shouldRegen {
log.Printf("Generating 4096-bit RSA keypair for self-signed certificate...") log.Printf("Generating 4096-bit RSA keypair for self-signed certificate...")
err := GenerateSelfSignedCert(certFn, keyFn) err := GenerateSelfSignedCert(certFn, keyFn, cmdArgs.DataDir)
if err != nil { if err != nil {
log.Printf("Error: %v", err) log.Printf("Error: %v", err)
return return
@ -111,8 +110,8 @@ func main() {
} }
// Should we import data from a Murmur SQLite file? // Should we import data from a Murmur SQLite file?
if SQLiteSupport && len(Args.SQLiteDB) > 0 { if SQLiteSupport && len(cmdArgs.SQLiteDB) > 0 {
f, err := os.Open(Args.DataDir) f, err := os.Open(cmdArgs.DataDir)
if err != nil { if err != nil {
log.Fatalf("Murmur import failed: %s", err.Error()) log.Fatalf("Murmur import failed: %s", err.Error())
} }
@ -123,20 +122,20 @@ func main() {
log.Fatalf("Murmur import failed: %s", err.Error()) log.Fatalf("Murmur import failed: %s", err.Error())
} }
if !Args.CleanUp && len(names) > 0 { if !cmdArgs.CleanUp && len(names) > 0 {
log.Fatalf("Non-empty datadir. Refusing to import Murmur data.") log.Fatalf("Non-empty datadir. Refusing to import Murmur data.")
} }
if Args.CleanUp { if cmdArgs.CleanUp {
log.Print("Cleaning up existing data directory") log.Print("Cleaning up existing data directory")
for _, name := range names { for _, name := range names {
if err := os.RemoveAll(filepath.Join(Args.DataDir, name)); err != nil { if err := os.RemoveAll(filepath.Join(cmdArgs.DataDir, name)); err != nil {
log.Fatalf("Unable to cleanup file: %s", name) log.Fatalf("Unable to cleanup file: %s", name)
} }
} }
} }
log.Printf("Importing Murmur data from '%s'", Args.SQLiteDB) log.Printf("Importing Murmur data from '%s'", cmdArgs.SQLiteDB)
if err = MurmurImport(Args.SQLiteDB); err != nil { if err = MurmurImport(cmdArgs.SQLiteDB, cmdArgs.DataDir); err != nil {
log.Fatalf("Murmur import failed: %s", err.Error()) log.Fatalf("Murmur import failed: %s", err.Error())
} }
@ -148,7 +147,7 @@ func main() {
// Create the servers directory if it doesn't already // Create the servers directory if it doesn't already
// exist. // exist.
serversDirPath := filepath.Join(Args.DataDir, "servers") serversDirPath := filepath.Join(cmdArgs.DataDir, "servers")
err = os.Mkdir(serversDirPath, 0700) err = os.Mkdir(serversDirPath, 0700)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
log.Fatalf("Unable to create servers directory: %v", err) log.Fatalf("Unable to create servers directory: %v", err)
@ -177,7 +176,7 @@ func main() {
for _, name := range names { for _, name := range names {
if matched, _ := regexp.MatchString("^[0-9]+$", name); matched { if matched, _ := regexp.MatchString("^[0-9]+$", name); matched {
log.Printf("Loading server %v", name) log.Printf("Loading server %v", name)
s, err := NewServerFromFrozen(name) s, err := NewServerFromFrozen(name, cmdArgs.DataDir)
if err != nil { if err != nil {
log.Fatalf("Unable to load server: %v", err.Error()) log.Fatalf("Unable to load server: %v", err.Error())
} }
@ -192,6 +191,7 @@ func main() {
// If no servers were found, create the default virtual server. // If no servers were found, create the default virtual server.
if len(servers) == 0 { if len(servers) == 0 {
s, err := NewServer(1) s, err := NewServer(1)
s.DataDir = cmdArgs.DataDir
if err != nil { if err != nil {
log.Fatalf("Couldn't start server: %s", err.Error()) log.Fatalf("Couldn't start server: %s", err.Error())
} }

View file

@ -39,7 +39,7 @@ const (
const SQLiteSupport = true const SQLiteSupport = true
// Import the structure of an existing Murmur SQLite database. // Import the structure of an existing Murmur SQLite database.
func MurmurImport(filename string) (err error) { func MurmurImport(filename string, dataDir string) (err error) {
db, err := sql.Open("sqlite", filename) db, err := sql.Open("sqlite", filename)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
@ -68,7 +68,7 @@ func MurmurImport(filename string) (err error) {
return err return err
} }
err = os.Mkdir(filepath.Join(Args.DataDir, strconv.FormatInt(sid, 10)), 0750) err = os.Mkdir(filepath.Join(dataDir, strconv.FormatInt(sid, 10)), 0750)
if err != nil { if err != nil {
return err return err
} }

View file

@ -122,6 +122,9 @@ type Server struct {
// Logging // Logging
*log.Logger *log.Logger
// Other configuration
DataDir string
} }
type clientLogForwarder struct { type clientLogForwarder struct {
@ -1425,8 +1428,8 @@ func (server *Server) Start() (err error) {
*/ */
// Wrap a TLS listener around the TCP connection // Wrap a TLS listener around the TCP connection
certFn := filepath.Join(Args.DataDir, "cert.pem") certFn := filepath.Join(server.DataDir, "cert.pem")
keyFn := filepath.Join(Args.DataDir, "key.pem") keyFn := filepath.Join(server.DataDir, "key.pem")
cert, err := tls.LoadX509KeyPair(certFn, keyFn) cert, err := tls.LoadX509KeyPair(certFn, keyFn)
if err != nil { if err != nil {
return err return err