1
0
Fork 0
forked from External/grumble

Refine auto-keypair regen functionality. Complain loudly if only one of (cert.pem, key.pem) exists.

This commit is contained in:
Mikkel Krautz 2011-11-10 13:17:18 +01:00
parent 76747e0b71
commit b94b04d2de
4 changed files with 67 additions and 42 deletions

View file

@ -55,7 +55,7 @@ func GenerateSelfSignedCert(certpath, keypath string) (err error) {
Bytes: keybuf,
}
certfn := filepath.Join(Args.DataDir, "cert")
certfn := filepath.Join(Args.DataDir, "cert.pem")
file, err := os.OpenFile(certfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700)
if err != nil {
return err
@ -66,7 +66,7 @@ func GenerateSelfSignedCert(certpath, keypath string) (err error) {
return err
}
keyfn := filepath.Join(Args.DataDir, "key")
keyfn := filepath.Join(Args.DataDir, "key.pem")
file, err = os.OpenFile(keyfn, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0700)
if err != nil {
return err

View file

@ -76,23 +76,44 @@ func main() {
// These are used as the default certificate of all virtual servers
// and the SSH admin console, but can be overridden using the "key"
// and "cert" arguments to Grumble.
certFn := filepath.Join(Args.DataDir, "cert")
keyFn := filepath.Join(Args.DataDir, "key")
certFn := filepath.Join(Args.DataDir, "cert.pem")
keyFn := filepath.Join(Args.DataDir, "key.pem")
shouldRegen := false
if Args.RegenKeys {
shouldRegen = true
} else {
files := []string{certFn, keyFn}
for _, fn := range files {
_, err := os.Stat(fn)
if err != nil {
if e, ok := err.(*os.PathError); ok {
if e.Err == os.ENOENT {
shouldRegen = true
}
// OK. Here's the idea: We check for the existence of the cert.pem
// and key.pem files in the data directory on launch. Although these
// might be deleted later (and this check could be deemed useless),
// it's simply here to be convenient for admins.
hasKey := true
hasCert := true
_, err = os.Stat(certFn)
if err != nil {
if e, ok := err.(*os.PathError); ok {
if e.Err == os.ENOENT {
hasCert = false
}
}
}
_, err = os.Stat(keyFn)
if err != nil {
if e, ok := err.(*os.PathError); ok {
if e.Err == os.ENOENT {
hasKey = false
}
}
}
if !hasCert && !hasKey {
shouldRegen = true
} else if !hasCert || !hasKey {
if !hasCert {
log.Fatal("Grumble could not find its default certificate (cert.pem)")
}
if !hasKey {
log.Fatal("Grumble could not find its default private key (key.pem)")
}
}
}
if shouldRegen {
log.Printf("Generating 4096-bit RSA keypair for self-signed certificate...")
@ -143,6 +164,9 @@ func main() {
return
}
// Run the SSH admin console.
RunSSH()
// Read all entries of the data directory.
// We need these to load our virtual servers.
names, err := dataDir.Readdirnames(-1)
@ -189,9 +213,6 @@ func main() {
go s.ListenAndMurmur()
}
// Run the SSH admin console.
go RunSSH()
// If any servers were loaded, launch the signal
// handler goroutine and sleep...
if len(servers) > 0 {

View file

@ -1190,7 +1190,9 @@ func (s *Server) ListenAndMurmur() {
go s.SendUDP()
// Create a new listening TLS socket.
cert, err := tls.LoadX509KeyPair(filepath.Join(Args.DataDir, "cert"), filepath.Join(Args.DataDir, "key"))
certFn := filepath.Join(Args.DataDir, "cert.pem")
keyFn := filepath.Join(Args.DataDir, "key.pem")
cert, err := tls.LoadX509KeyPair(certFn, keyFn)
if err != nil {
s.Printf("Unable to load x509 key pair: %v", err)
return

54
ssh.go
View file

@ -72,7 +72,7 @@ func RunSSH() {
"<id> <key> <value>",
"Get a config value for the server with the given id")
pemBytes, err := ioutil.ReadFile(filepath.Join(Args.DataDir, "key"))
pemBytes, err := ioutil.ReadFile(filepath.Join(Args.DataDir, "key.pem"))
if err != nil {
log.Fatal(err)
}
@ -89,32 +89,34 @@ func RunSSH() {
log.Printf("Listening for SSH connections on '%v'", Args.SshAddr)
for {
conn, err := listener.Accept()
if err != nil {
log.Fatalf("ssh: unable to accept incoming connection: %v", err)
}
err = conn.Handshake()
if err == io.EOF {
continue
} else if err != nil {
log.Fatalf("ssh: unable to perform handshake: %v", err)
}
go func() {
for {
channel, err := conn.Accept()
if err == io.EOF {
return
} else if err != nil {
log.Fatalf("ssh: unable to accept channel: %v", err)
}
go handleChannel(channel)
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Fatalf("ssh: unable to accept incoming connection: %v", err)
}
}()
}
err = conn.Handshake()
if err == io.EOF {
continue
} else if err != nil {
log.Fatalf("ssh: unable to perform handshake: %v", err)
}
go func() {
for {
channel, err := conn.Accept()
if err == io.EOF {
return
} else if err != nil {
log.Fatalf("ssh: unable to accept channel: %v", err)
}
go handleChannel(channel)
}
}()
}
}()
}
func handleChannel(channel ssh.Channel) {