diff --git a/client.go b/client.go index 84d5cdd..7e843fa 100644 --- a/client.go +++ b/client.go @@ -49,6 +49,12 @@ func (client *Client) Panic(reason string) { // so it can perform a proper disconnect. } +func (client *Client) Disconnect() { + client.disconnected = true + close(client.udprecv) + close(client.msgchan) +} + // Read a protobuf message from a client func (client *Client) readProtoMessage() (msg *Message, err os.Error) { var length uint32 @@ -101,8 +107,13 @@ func (c *Client) sendProtoMessage(kind uint16, msg interface{}) (err os.Error) { // UDP receiver. func (client *Client) udpreceiver() { - for { - buf := <-client.udprecv + for buf := range client.udprecv { + + // Channel close. + if len(buf) == 0 { + return + } + kind := (buf[0] >> 5) & 0x07; switch kind { @@ -162,8 +173,11 @@ func (client *Client) sendUdp(msg *Message) { // Sender Goroutine // func (client *Client) sender() { - for { - msg := <-client.msgchan + for msg := range client.msgchan { + // Check for channel close. + if len(msg.buf) == 0 { + return + } // First, we write out the message type as a big-endian uint16 err := binary.Write(client.writer, binary.BigEndian, msg.kind) @@ -198,12 +212,17 @@ func (client *Client) sender() { // Receiver Goroutine func (client *Client) receiver() { for { - // The version handshake is done. Forward this message to the synchronous request handler. if client.state == StateClientAuthenticated || client.state == StateClientSentVersion { // Try to read the next message in the pool msg, err := client.readProtoMessage() if err != nil { + if err == os.EOF { + log.Printf("Client disconnected.") + client.Disconnect() + } else { + log.Printf("Client error.") + } return } // Special case UDPTunnel messages. They're high priority and shouldn't @@ -231,6 +250,12 @@ func (client *Client) receiver() { } else if client.state == StateServerSentVersion { msg, err := client.readProtoMessage() if err != nil { + if err == os.EOF { + log.Printf("Client disconnected.") + client.Disconnect() + } else { + log.Printf("Client error.") + } return } @@ -268,18 +293,14 @@ func (client *Client) sendChannelList() { // Send the userlist to a client. func (client *Client) sendUserList() { server := client.server - - server.cmutex.RLock() - defer server.cmutex.RUnlock() - - for _, user := range server.clients { - err := user.sendProtoMessage(MessageUserState, &mumbleproto.UserState{ + for _, client := range server.clients { + err := client.sendProtoMessage(MessageUserState, &mumbleproto.UserState{ Session: proto.Uint32(client.Session), Name: proto.String(client.Username), ChannelId: proto.Uint32(0), }) if err != nil { - log.Printf("unable to send!") + log.Printf("Unable to send UserList") continue } } diff --git a/message.go b/message.go index 636b65c..26ab65b 100644 --- a/message.go +++ b/message.go @@ -57,7 +57,7 @@ type Message struct { // is ignored for UDP packets. kind uint16 - // For UDP datagrams one of these fiels have to be filled out. + // For UDP datagrams one of these fields have to be filled out. // If there is no connection established, address must be used. // If the datagram comes from an already-connected client, the // client field should point to that client. @@ -143,7 +143,10 @@ func (server *Server) handleTextMessage(client *Client, msg *Message) { users := []*Client{}; for i := 0; i < len(txtmsg.Session); i++ { - user := server.getClient(txtmsg.Session[i]) + user, ok := server.clients[txtmsg.Session[i]] + if !ok { + log.Panic("Could not look up client by session") + } users = append(users, user) } @@ -170,4 +173,5 @@ func (server *Server) handleUserStatsMessage(client *Client, msg *Message) { if err != nil { client.Panic(err.String()) } + log.Printf("UserStatsMessage") } diff --git a/server.go b/server.go index bf55935..9b3418a 100644 --- a/server.go +++ b/server.go @@ -43,18 +43,19 @@ type Server struct { incoming chan *Message outgoing chan *Message - udpsend chan *Message // Config-related MaxUsers int MaxBandwidth uint32 + // Clients session uint32 + clients map[uint32]*Client - // A list of all connected clients - cmutex *sync.RWMutex - clients []*Client + hmutex *sync.RWMutex + hclients map[string][]*Client + hpclients map[string]*Client // Codec information AlphaCodec int32 @@ -81,8 +82,11 @@ func NewServer(addr string, port int) (s *Server, err os.Error) { s.address = addr s.port = port - // Create the list of connected clients - s.cmutex = new(sync.RWMutex) + s.clients = make(map[uint32]*Client) + + s.hmutex = new(sync.RWMutex) + s.hclients = make(map[string][]*Client) + s.hpclients = make(map[string]*Client) s.outgoing = make(chan *Message) s.incoming = make(chan *Message) @@ -91,8 +95,6 @@ func NewServer(addr string, port int) (s *Server, err os.Error) { s.MaxBandwidth = 300000 s.MaxUsers = 10 - // Allocate the root channel - s.root = &Channel{ Id: 0, Name: "Root", @@ -107,13 +109,13 @@ func NewServer(addr string, port int) (s *Server, err os.Error) { // Called by the server to initiate a new client connection. func (server *Server) NewClient(conn net.Conn) (err os.Error) { client := new(Client) - - // Get the address of the connected client - if addr := conn.RemoteAddr(); addr != nil { - client.tcpaddr = addr.(*net.TCPAddr) - log.Printf("client connected: %s", client.tcpaddr.String()) + addr := conn.RemoteAddr() + if addr == nil { + err = os.NewError("Unable to extract address for client.") + return } + client.tcpaddr = addr.(*net.TCPAddr) client.server = server client.conn = conn client.reader = bufio.NewReader(client.conn) @@ -123,15 +125,6 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) { client.msgchan = make(chan *Message) client.udprecv = make(chan []byte) - // New client connection.... - server.session += 1 - client.Session = server.session - - // Add it to the list of connected clients - server.cmutex.Lock() - server.clients = append(server.clients, client) - server.cmutex.Unlock() - go client.receiver() go client.udpreceiver() go client.sender() @@ -139,21 +132,9 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) { return } -// Lookup a client by it's session id. Optimize this by using a map. -func (server *Server) getClient(session uint32) (client *Client) { - server.cmutex.RLock() - defer server.cmutex.RUnlock() - - for _, user := range server.clients { - if user.Session == session { - return user - } - } - - return nil -} - -// This is the synchronous request handler for all incoming messages. +// This is the synchronous handler goroutine. +// Important control channel messages are routed through this Goroutine +// to keep server state synchronized. func (server *Server) handler() { for { msg := <-server.incoming @@ -167,6 +148,7 @@ func (server *Server) handler() { } } +// Handle a Authenticate protobuf message. func (server *Server) handleAuthenticate(client *Client, msg *Message) { // Is this message not an authenticate message? If not, discard it... if msg.kind != MessageAuthenticate { @@ -201,8 +183,8 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) { return } - // Send CryptState information to the client so it can establish an UDP connection - // (if it wishes)... + // Send CryptState information to the client so it can establish an UDP connection, + // if it wishes. err = client.sendProtoMessage(MessageCryptSetup, &mumbleproto.CryptSetup{ Key: client.crypt.RawKey[0:], ClientNonce: client.crypt.DecryptIV[0:], @@ -219,7 +201,18 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) { client.state = StateClientAuthenticated - // Broadcast that we, the client, entered a channel... + // Add the client to the connected list + server.session += 1 + client.Session = server.session + server.clients[client.Session] = client + + // Add the client to the host slice for its host address. + host := client.tcpaddr.IP.String() + server.hmutex.Lock() + server.hclients[host] = append(server.hclients[host], client) + server.hmutex.Unlock() + + // Broadcast the the user entered a channel err = server.broadcastProtoMessage(MessageUserState, &mumbleproto.UserState{ Session: proto.Uint32(client.Session), Name: proto.String(client.Username), @@ -258,9 +251,6 @@ func (server *Server) updateCodecVersions() { var winner int32 var count int - server.cmutex.RLock() - defer server.cmutex.RUnlock() - for _, client := range server.clients { for i := 0; i < len(client.codecs); i++ { codecusers[client.codecs[i]] += 1 @@ -314,9 +304,6 @@ func (server *Server) updateCodecVersions() { } func (server *Server) sendUserList(client *Client) { - server.cmutex.RLock() - defer server.cmutex.RUnlock() - for _, user := range server.clients { if user.state != StateClientAuthenticated { continue @@ -339,9 +326,6 @@ func (server *Server) sendUserList(client *Client) { } func (server *Server) broadcastProtoMessage(kind uint16, msg interface{}) (err os.Error) { - server.cmutex.RLock() - defer server.cmutex.RUnlock() - for _, client := range server.clients { if client.state != StateClientAuthenticated { continue @@ -476,55 +460,45 @@ func (server *Server) ListenUDP() { } else { var match *Client plain := make([]byte, nread-4) - decrypted := false - // First, check if any of our clients match the net.UDPAddr... - server.cmutex.RLock() - for _, client := range server.clients { - if client.udpaddr.String() == udpaddr.String() { - match = client + // Determine which client sent the the packet. First, we + // check the map 'hpclients' in the server struct. It maps + // a hort-post combination to a client. + // + // If we don't find any matches, we look in the 'hclients', + // which maps a host address to a slice of clients. + server.hmutex.RLock() + defer server.hmutex.RUnlock() + client, ok := server.hpclients[udpaddr.String()] + if ok { + err = client.crypt.Decrypt(buf[0:nread], plain[0:]) + if err != nil { + log.Panicf("Unable to decrypt incoming packet for client %v (host-port matched)", client) } - } - server.cmutex.RUnlock() - - // No matching client found. We must try to decrypt... - if match == nil { - server.cmutex.RLock() - for _, client := range server.clients { - // Try to decrypt. + match = client + } else { + host := udpaddr.IP.String() + server.hmutex.RLock() + hostclients := server.hclients[host] + for _, client := range hostclients { err = client.crypt.Decrypt(buf[0:nread], plain[0:]) if err != nil { - // Decryption failed. Try another client... continue + } else { + match = client } - - // Decryption succeeded. - decrypted = true - - // If we were able to successfully decrpyt, add - // the UDPAddr to the Client struct. - log.Printf("Client UDP connection established.") - client.udpaddr = remote.(*net.UDPAddr) - match = client - - break } - server.cmutex.RUnlock() } - // We were not able to find a client that could decrypt the incoming - // packet. Log it? + // No client found. if match == nil { + log.Printf("No match found for packet. Discarding...") continue } - if !decrypted { - err = match.crypt.Decrypt(buf[0:nread], plain[0:]) - if err != nil { - log.Printf("Unable to decrypt from client..") - } + if match.udpaddr == nil { + match.udpaddr = udpaddr } - match.udp = true match.udprecv <- plain } @@ -553,16 +527,15 @@ func (s *Server) ListenAndMurmur() { // when we do get a new connection, we spawn // a new Go-routine to handle the client. for { - // New client connected conn, err := l.Accept() if err != nil { - log.Printf("unable to accept()") + log.Printf("Unable to accept() new client.") } tls, ok := conn.(*tls.Conn) if !ok { - log.Printf("Not tls :(") + log.Panic("Internal inconsistency error.") } // Force the TLS handshake to get going. We'd like