diff --git a/gencert.go b/gencert.go index 43f8feb..926a56c 100644 --- a/gencert.go +++ b/gencert.go @@ -55,7 +55,7 @@ func GenerateSelfSignedCert(certpath, keypath string) (err error) { Bytes: keybuf, } - certfn := filepath.Join(Args.DataDir, "cert") + certfn := filepath.Join(Args.DataDir, "cert.pem") file, err := os.OpenFile(certfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700) if err != nil { return err @@ -66,7 +66,7 @@ func GenerateSelfSignedCert(certpath, keypath string) (err error) { return err } - keyfn := filepath.Join(Args.DataDir, "key") + keyfn := filepath.Join(Args.DataDir, "key.pem") file, err = os.OpenFile(keyfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700) if err != nil { return err diff --git a/grumble.go b/grumble.go index 32245c1..6cb6646 100644 --- a/grumble.go +++ b/grumble.go @@ -76,23 +76,44 @@ func main() { // These are used as the default certificate of all virtual servers // and the SSH admin console, but can be overridden using the "key" // and "cert" arguments to Grumble. - certFn := filepath.Join(Args.DataDir, "cert") - keyFn := filepath.Join(Args.DataDir, "key") + certFn := filepath.Join(Args.DataDir, "cert.pem") + keyFn := filepath.Join(Args.DataDir, "key.pem") shouldRegen := false if Args.RegenKeys { shouldRegen = true } else { - files := []string{certFn, keyFn} - for _, fn := range files { - _, err := os.Stat(fn) - if err != nil { - if e, ok := err.(*os.PathError); ok { - if e.Err == os.ENOENT { - shouldRegen = true - } + // OK. Here's the idea: We check for the existence of the cert.pem + // and key.pem files in the data directory on launch. Although these + // might be deleted later (and this check could be deemed useless), + // it's simply here to be convenient for admins. + hasKey := true + hasCert := true + _, err = os.Stat(certFn) + if err != nil { + if e, ok := err.(*os.PathError); ok { + if e.Err == os.ENOENT { + hasCert = false } } } + _, err = os.Stat(keyFn) + if err != nil { + if e, ok := err.(*os.PathError); ok { + if e.Err == os.ENOENT { + hasKey = false + } + } + } + if !hasCert && !hasKey { + shouldRegen = true + } else if !hasCert || !hasKey { + if !hasCert { + log.Fatal("Grumble could not find its default certificate (cert.pem)") + } + if !hasKey { + log.Fatal("Grumble could not find its default private key (key.pem)") + } + } } if shouldRegen { log.Printf("Generating 4096-bit RSA keypair for self-signed certificate...") @@ -143,6 +164,9 @@ func main() { return } + // Run the SSH admin console. + RunSSH() + // Read all entries of the data directory. // We need these to load our virtual servers. names, err := dataDir.Readdirnames(-1) @@ -189,9 +213,6 @@ func main() { go s.ListenAndMurmur() } - // Run the SSH admin console. - go RunSSH() - // If any servers were loaded, launch the signal // handler goroutine and sleep... if len(servers) > 0 { diff --git a/server.go b/server.go index dd2b05e..39a4e35 100644 --- a/server.go +++ b/server.go @@ -1190,7 +1190,9 @@ func (s *Server) ListenAndMurmur() { go s.SendUDP() // Create a new listening TLS socket. - cert, err := tls.LoadX509KeyPair(filepath.Join(Args.DataDir, "cert"), filepath.Join(Args.DataDir, "key")) + certFn := filepath.Join(Args.DataDir, "cert.pem") + keyFn := filepath.Join(Args.DataDir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFn, keyFn) if err != nil { s.Printf("Unable to load x509 key pair: %v", err) return diff --git a/ssh.go b/ssh.go index fc7ca7f..1d4ce9b 100644 --- a/ssh.go +++ b/ssh.go @@ -72,7 +72,7 @@ func RunSSH() { " ", "Get a config value for the server with the given id") - pemBytes, err := ioutil.ReadFile(filepath.Join(Args.DataDir, "key")) + pemBytes, err := ioutil.ReadFile(filepath.Join(Args.DataDir, "key.pem")) if err != nil { log.Fatal(err) } @@ -89,32 +89,34 @@ func RunSSH() { log.Printf("Listening for SSH connections on '%v'", Args.SshAddr) - for { - conn, err := listener.Accept() - if err != nil { - log.Fatalf("ssh: unable to accept incoming connection: %v", err) - } - - err = conn.Handshake() - if err == io.EOF { - continue - } else if err != nil { - log.Fatalf("ssh: unable to perform handshake: %v", err) - } - - go func() { - for { - channel, err := conn.Accept() - if err == io.EOF { - return - } else if err != nil { - log.Fatalf("ssh: unable to accept channel: %v", err) - } - - go handleChannel(channel) + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.Fatalf("ssh: unable to accept incoming connection: %v", err) } - }() - } + + err = conn.Handshake() + if err == io.EOF { + continue + } else if err != nil { + log.Fatalf("ssh: unable to perform handshake: %v", err) + } + + go func() { + for { + channel, err := conn.Accept() + if err == io.EOF { + return + } else if err != nil { + log.Fatalf("ssh: unable to accept channel: %v", err) + } + + go handleChannel(channel) + } + }() + } + }() } func handleChannel(channel ssh.Channel) {