From 8db9a7043ac8a345fe419316ff9844eacd75ce0d Mon Sep 17 00:00:00 2001 From: Mikkel Krautz Date: Sat, 20 Nov 2010 15:04:38 +0100 Subject: [PATCH] Disconnect clients properly. --- client.go | 6 +++--- server.go | 31 ++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index a3d6b3e..2acc8c5 100644 --- a/client.go +++ b/client.go @@ -44,15 +44,15 @@ type Client struct { // Something invalid happened on the wire. func (client *Client) Panic(reason string) { - client.disconnected = true - // fixme(mkrautz): we should inform the server "handler" method through a channel of this event, - // so it can perform a proper disconnect. + client.Disconnect() } func (client *Client) Disconnect() { client.disconnected = true close(client.udprecv) close(client.msgchan) + + client.server.RemoveClient(client) } // Read a protobuf message from a client diff --git a/server.go b/server.go index 8ea0abb..5edb732 100644 --- a/server.go +++ b/server.go @@ -53,7 +53,7 @@ type Server struct { session uint32 clients map[uint32]*Client - hmutex *sync.RWMutex + hmutex sync.Mutex hclients map[string][]*Client hpclients map[string]*Client @@ -84,7 +84,6 @@ func NewServer(addr string, port int) (s *Server, err os.Error) { s.clients = make(map[uint32]*Client) - s.hmutex = new(sync.RWMutex) s.hclients = make(map[string][]*Client) s.hpclients = make(map[string]*Client) @@ -132,6 +131,27 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) { return } +// Remove a disconnected client from the server's +// internal representation. +func (server *Server) RemoveClient(client *Client) { + server.hmutex.Lock() + defer server.hmutex.Unlock() + + if client.udpaddr != nil { + host := client.udpaddr.IP.String() + oldclients := server.hclients[host] + newclients := []*Client{} + for _, hostclient := range oldclients { + if hostclient != client { + newclients = append(newclients, hostclient) + } + } + server.hclients[host] = newclients + server.hpclients[client.udpaddr.String()] = nil, false + } + server.clients[client.Session] = nil, false +} + // This is the synchronous handler goroutine. // Important control channel messages are routed through this Goroutine // to keep server state synchronized. @@ -198,7 +218,6 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) { server.updateCodecVersions() client.sendChannelList() - client.state = StateClientAuthenticated // Add the client to the connected list @@ -299,7 +318,6 @@ func (server *Server) updateCodecVersions() { } log.Printf("CELT codec switch %v %v (PreferAlpha %v)", server.AlphaCodec, server.BetaCodec, server.PreferAlphaCodec) - return } @@ -467,8 +485,7 @@ func (server *Server) ListenUDP() { // // If we don't find any matches, we look in the 'hclients', // which maps a host address to a slice of clients. - server.hmutex.RLock() - defer server.hmutex.RUnlock() + server.hmutex.Lock() client, ok := server.hpclients[udpaddr.String()] if ok { err = client.crypt.Decrypt(buf[0:nread], plain[0:]) @@ -478,7 +495,6 @@ func (server *Server) ListenUDP() { match = client } else { host := udpaddr.IP.String() - server.hmutex.RLock() hostclients := server.hclients[host] for _, client := range hostclients { err = client.crypt.Decrypt(buf[0:nread], plain[0:]) @@ -489,6 +505,7 @@ func (server *Server) ListenUDP() { } } } + server.hmutex.Unlock() // No client found. if match == nil {