diff --git a/client.go b/client.go index 71e6274..a6b3a2e 100644 --- a/client.go +++ b/client.go @@ -33,7 +33,7 @@ type Client struct { state int server *Server - udprecv chan []byte + udprecv chan []byte disconnected bool @@ -160,7 +160,7 @@ func (client *Client) disconnect(kicked bool) { } } -// Disconnect a client (client disconnected) +// Disconnect a client (client requested or server shutdown) func (client *Client) Disconnect() { client.disconnect(false) } @@ -352,7 +352,7 @@ func (client *Client) sendMessage(msg interface{}) error { kind = mumbleproto.MessageType(msg) if kind == mumbleproto.MessageUDPTunnel { msgData = msg.([]byte) - } else { + } else { msgData, err = proto.Marshal(msg) if err != nil { return err diff --git a/grumble.go b/grumble.go index 6cb6646..a1123f6 100644 --- a/grumble.go +++ b/grumble.go @@ -59,7 +59,7 @@ func main() { exists := false if e, ok := err.(*os.PathError); ok { if e.Err == os.EEXIST { - exists = true + exists = true } } if !exists { @@ -195,7 +195,10 @@ func main() { log.Fatalf("Unable to freeze server to disk: %v", err.Error()) } servers[s.Id] = s - go s.ListenAndMurmur() + err = s.Start() + if err != nil { + log.Printf("Unable to start server %v: %v", s.Id, err.Error()) + } } } @@ -207,10 +210,15 @@ func main() { } servers[s.Id] = s - os.Mkdir(filepath.Join(Args.DataDir, fmt.Sprintf("%v", 1)), 0750) - s.FreezeToFile() - go s.ListenAndMurmur() + err = s.FreezeToFile() + if err != nil { + log.Fatalf("Unable to freeze newly created server to disk: %v", err.Error()) + } + err = s.Start() + if err != nil { + log.Fatal("Unable to start newly created server: %v", err.Error()) + } } // If any servers were loaded, launch the signal @@ -219,4 +227,4 @@ func main() { go SignalHandler() select {} } -} \ No newline at end of file +} diff --git a/message.go b/message.go index 04d46e7..6b56416 100644 --- a/message.go +++ b/message.go @@ -18,18 +18,9 @@ import ( ) type Message struct { - buf []byte - - // Kind denotes a message kind for TCP packets. This field - // is ignored for UDP packets. - kind uint16 - - // For UDP datagrams one of these fields have to be filled out. - // If there is no connection established, address must be used. - // If the datagram comes from an already-connected client, the - // client field should point to that client. - client *Client - address net.Addr + buf []byte + kind uint16 + client *Client } type VoiceBroadcast struct { diff --git a/register.go b/register.go index db8b1b5..8d59218 100644 --- a/register.go +++ b/register.go @@ -31,7 +31,7 @@ type Register struct { Location string `xml:"location"` } -const registerUrl = "https://mumble.hive.no/register.cgi" +const registerUrl = "https://mumble.hive.no/register.cgi" // Determines whether a server is public by checking whether the // config values required for public registration are set. @@ -102,7 +102,7 @@ func (server *Server) RegisterPublicServer() { // Post registration XML data to server asynchronously in its own goroutine go func() { tr := &http.Transport{ - TLSClientConfig: config, + TLSClientConfig: config, } client := &http.Client{Transport: tr} r, err := client.Post(registerUrl, "text/xml", ioutil.NopCloser(buf)) diff --git a/server.go b/server.go index 8a5339f..d619810 100644 --- a/server.go +++ b/server.go @@ -55,13 +55,17 @@ type KeyValuePair struct { // A Murmur server instance type Server struct { - Id int64 - listener tls.Listener - address string - port int - udpconn *net.UDPConn - tlscfg *tls.Config - running bool + Id int64 + + tcpl *net.TCPListener + tlsl *tls.Listener + address string + port int + udpconn *net.UDPConn + tlscfg *tls.Config + bye chan bool + netwg sync.WaitGroup + running bool incoming chan *Message voicebroadcast chan *VoiceBroadcast @@ -148,6 +152,7 @@ func NewServer(id int64, addr string, port int) (s *Server, err error) { s.hclients = make(map[string][]*Client) s.hpclients = make(map[string]*Client) + s.bye = make(chan bool) s.incoming = make(chan *Message) s.voicebroadcast = make(chan *VoiceBroadcast) s.cfgUpdate = make(chan *KeyValuePair) @@ -345,10 +350,13 @@ func (server *Server) UnlinkChannels(channel *Channel, other *Channel) { // This is the synchronous handler goroutine. // Important control channel messages are routed through this Goroutine // to keep server state synchronized. -func (server *Server) handler() { +func (server *Server) handlerLoop() { regtick := time.Tick((3600 + ((server.Id * 60) % 600)) * 1e9) for { select { + // We're done. Stop the server's event handler + case <-server.bye: + return // Control channel messages case msg := <-server.incoming: client := msg.client @@ -863,19 +871,6 @@ func (server *Server) handleIncomingMessage(client *Client, msg *Message) { } } -func (s *Server) SetupUDP() (err error) { - addr := &net.UDPAddr{ - net.ParseIP(s.address), - s.port, - } - s.udpconn, err = net.ListenUDP("udp", addr) - if err != nil { - return - } - - return -} - // Send the content of buf as a UDP packet to addr. func (s *Server) SendUDP(buf []byte, addr *net.UDPAddr) (err error) { _, err = s.udpconn.WriteTo(buf, addr) @@ -883,13 +878,18 @@ func (s *Server) SendUDP(buf []byte, addr *net.UDPAddr) (err error) { } // Listen for and handle UDP packets. -func (server *Server) ListenUDP() { +func (server *Server) udpListenLoop() { + defer server.netwg.Done() + buf := make([]byte, UDPPacketSize) for { nread, remote, err := server.udpconn.ReadFrom(buf) if err != nil { - // Not much to do here. This is bad, of course. Should we panic this server instance? - continue + if isTimeout(err) { + continue + } else { + return + } } udpaddr, ok := remote.(*net.UDPAddr) @@ -917,7 +917,7 @@ func (server *Server) ListenUDP() { err = server.SendUDP(buffer.Bytes(), udpaddr) if err != nil { - server.Print("Unable to write UDP packet: %v", err.Error()) + return } } else { @@ -1153,93 +1153,180 @@ func (server *Server) FilterText(text string) (filtered string, err error) { } // The accept loop of the server. -func (s *Server) ListenAndMurmur() { - // Launch the event handler goroutine - go s.handler() +func (server *Server) acceptLoop() { + defer server.netwg.Done() - host := s.cfg.StringValue("Address") - if host != "" { - s.address = host - } - port := s.cfg.IntValue("Port") - if port != 0 { - s.port = port - } - - s.running = true - - // Setup our UDP listener and spawn our reader and writer goroutines - s.SetupUDP() - go s.ListenUDP() - - // Create a new listening TLS socket. - certFn := filepath.Join(Args.DataDir, "cert.pem") - keyFn := filepath.Join(Args.DataDir, "key.pem") - cert, err := tls.LoadX509KeyPair(certFn, keyFn) - if err != nil { - s.Printf("Unable to load x509 key pair: %v", err) - return - } - - cfg := new(tls.Config) - cfg.Certificates = append(cfg.Certificates, cert) - cfg.AuthenticateClient = true - s.tlscfg = cfg - - tl, err := net.ListenTCP("tcp", &net.TCPAddr{ - net.ParseIP(s.address), - s.port, - }) - if err != nil { - s.Printf("Cannot bind: %s\n", err) - return - } - - listener := tls.NewListener(tl, s.tlscfg) - - s.Printf("Started: listening on %v", tl.Addr()) - - // Open a fresh freezer log - err = s.openFreezeLog() - if err != nil { - s.Fatal(err) - } - - // Update server registration if needed. - go func() { - time.Sleep((60 + s.Id*10) * 1e9) - s.RegisterPublicServer() - }() - - // The main accept loop. Basically, we block - // until we get a new client connection, and - // when we do get a new connection, we spawn - // a new Go-routine to handle the client. for { // New client connected - conn, err := listener.Accept() + conn, err := server.tlsl.Accept() if err != nil { - s.Printf("Unable to accept new client: %v", err) - continue + if isTimeout(err) { + continue + } else { + return + } } // Is the client banned? // fixme(mkrautz): Clean up expired bans - if s.IsBanned(conn) { - s.Printf("Rejected client %v: Banned", conn.RemoteAddr()) + if server.IsBanned(conn) { + server.Printf("Rejected client %v: Banned", conn.RemoteAddr()) err := conn.Close() if err != nil { - s.Printf("Unable to close connection: %v", err) + server.Printf("Unable to close connection: %v", err) } continue } // Create a new client connection from our *tls.Conn // which wraps net.TCPConn. - err = s.NewClient(conn) + err = server.NewClient(conn) if err != nil { - s.Printf("Unable to handle new client: %v", err) + server.Printf("Unable to handle new client: %v", err) continue } } } + +// The isTimeout function checks whether a +// network error is a timeout. +func isTimeout(err error) bool { + if e, ok := err.(net.Error); ok { + return e.Timeout() + } + return false +} + +// Start the server. +func (server *Server) Start() (err error) { + if server.running { + return errors.New("already running") + } + + host := server.cfg.StringValue("Address") + if host != "" { + server.address = host + } + port := server.cfg.IntValue("Port") + if port != 0 { + server.port = port + } + + // Setup our UDP listener + addr := &net.UDPAddr{ + net.ParseIP(server.address), + server.port, + } + server.udpconn, err = net.ListenUDP("udp", addr) + if err != nil { + return err + } + err = server.udpconn.SetReadTimeout(1e9) + if err != nil { + return err + } + + // Set up our TCP connection + server.tcpl, err = net.ListenTCP("tcp", &net.TCPAddr{ + net.ParseIP(server.address), + server.port, + }) + if err != nil { + return err + } + err = server.tcpl.SetTimeout(1e9) + if err != nil { + return err + } + + // Wrap a TLS listener around the TCP connection + certFn := filepath.Join(Args.DataDir, "cert.pem") + keyFn := filepath.Join(Args.DataDir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFn, keyFn) + if err != nil { + return err + } + server.tlscfg = &tls.Config{ + Certificates: []tls.Certificate{cert}, + AuthenticateClient: true, + } + server.tlsl = tls.NewListener(server.tcpl, server.tlscfg) + + server.Printf("Started: listening on %v", server.tcpl.Addr()) + server.running = true + + // Open a fresh freezer log + err = server.openFreezeLog() + if err != nil { + server.Fatal(err) + } + + // Launch the event handler goroutine + go server.handlerLoop() + + // Add the two network receiver goroutines to the net waitgroup + // and launch them. + // + // We use the waitgroup to provide a blocking Stop() method + // for the servers. Each network goroutine defers a call to + // netwg.Done(). In the Stop() we close all the connections + // and call netwg.Wait() to wait for the goroutines to end. + server.netwg.Add(2) + go server.udpListenLoop() + go server.acceptLoop() + + // Schedule a server registration update (if needed) + go func() { + time.Sleep((60 + server.Id*10) * 1e9) + server.RegisterPublicServer() + }() + + return nil +} + +// Stop the server. +func (server *Server) Stop() (err error) { + if !server.running { + return errors.New("server not running") + } + + // Stop the handler goroutine and disconnect all + // clients + server.bye <- true + for _, client := range server.clients { + client.Disconnect() + } + + // Close the TLS listener and the TCP listener + err = server.tlsl.Close() + if err != nil { + return err + } + err = server.tcpl.Close() + if err != nil { + return err + } + + // Close the UDP connection + err = server.udpconn.Close() + if err != nil { + return err + } + + // Since we'll (on some OSes) have to wait for the network + // goroutines to end, we might as well use the time to store + // a full server freeze to disk. + err = server.FreezeToFile() + if err != nil { + server.Fatal(err) + } + + // Wait for the two network receiver + // goroutines end. + server.netwg.Wait() + + server.running = false + server.Printf("Stopped") + + return nil +} diff --git a/signal_unix.go b/signal_unix.go index c737fef..dc54646 100644 --- a/signal_unix.go +++ b/signal_unix.go @@ -22,7 +22,7 @@ func SignalHandler() { } continue } - + if sig == os.SIGINT || sig == os.SIGTERM { os.Exit(0) } diff --git a/ssh.go b/ssh.go index e808c64..c88f7f6 100644 --- a/ssh.go +++ b/ssh.go @@ -227,9 +227,14 @@ func StartServerCmd(reply SshCmdReply, args []string) error { return errors.New("no such server") } - _ = server + err = server.Start() + if err != nil { + return fmt.Errorf("unable to start: %v", err.Error()) + } - return errors.New("not implemented") + reply.WriteString(fmt.Sprintf("[%v] Started\r\n", serverId)) + + return nil } func StopServerCmd(reply SshCmdReply, args []string) error { @@ -247,9 +252,14 @@ func StopServerCmd(reply SshCmdReply, args []string) error { return errors.New("no such server") } - _ = server + err = server.Stop() + if err != nil { + return fmt.Errorf("unable to stop: %v", err.Error()) + } - return errors.New("not implemented") + reply.WriteString(fmt.Sprintf("[%v] Stopped\r\n", serverId)) + + return nil } func SetSuperUserPasswordCmd(reply SshCmdReply, args []string) error { @@ -291,7 +301,7 @@ func SetConfCmd(reply SshCmdReply, args []string) error { key := args[2] value := args[3] server.cfg.Set(key, value) - server.cfgUpdate <- &KeyValuePair{Key: key, Value: value} + server.cfgUpdate <- &KeyValuePair{Key: key, Value: value} reply.WriteString(fmt.Sprintf("[%v] %v = %v\r\n", serverId, key, value)) return nil