Disconnect clients properly.

This commit is contained in:
Mikkel Krautz 2010-11-20 15:04:38 +01:00
parent c2f3f0de47
commit 8db9a7043a
2 changed files with 27 additions and 10 deletions

View file

@ -44,15 +44,15 @@ type Client struct {
// Something invalid happened on the wire. // Something invalid happened on the wire.
func (client *Client) Panic(reason string) { func (client *Client) Panic(reason string) {
client.disconnected = true client.Disconnect()
// fixme(mkrautz): we should inform the server "handler" method through a channel of this event,
// so it can perform a proper disconnect.
} }
func (client *Client) Disconnect() { func (client *Client) Disconnect() {
client.disconnected = true client.disconnected = true
close(client.udprecv) close(client.udprecv)
close(client.msgchan) close(client.msgchan)
client.server.RemoveClient(client)
} }
// Read a protobuf message from a client // Read a protobuf message from a client

View file

@ -53,7 +53,7 @@ type Server struct {
session uint32 session uint32
clients map[uint32]*Client clients map[uint32]*Client
hmutex *sync.RWMutex hmutex sync.Mutex
hclients map[string][]*Client hclients map[string][]*Client
hpclients map[string]*Client hpclients map[string]*Client
@ -84,7 +84,6 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
s.clients = make(map[uint32]*Client) s.clients = make(map[uint32]*Client)
s.hmutex = new(sync.RWMutex)
s.hclients = make(map[string][]*Client) s.hclients = make(map[string][]*Client)
s.hpclients = make(map[string]*Client) s.hpclients = make(map[string]*Client)
@ -132,6 +131,27 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) {
return return
} }
// Remove a disconnected client from the server's
// internal representation.
func (server *Server) RemoveClient(client *Client) {
server.hmutex.Lock()
defer server.hmutex.Unlock()
if client.udpaddr != nil {
host := client.udpaddr.IP.String()
oldclients := server.hclients[host]
newclients := []*Client{}
for _, hostclient := range oldclients {
if hostclient != client {
newclients = append(newclients, hostclient)
}
}
server.hclients[host] = newclients
server.hpclients[client.udpaddr.String()] = nil, false
}
server.clients[client.Session] = nil, false
}
// This is the synchronous handler goroutine. // This is the synchronous handler goroutine.
// Important control channel messages are routed through this Goroutine // Important control channel messages are routed through this Goroutine
// to keep server state synchronized. // to keep server state synchronized.
@ -198,7 +218,6 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
server.updateCodecVersions() server.updateCodecVersions()
client.sendChannelList() client.sendChannelList()
client.state = StateClientAuthenticated client.state = StateClientAuthenticated
// Add the client to the connected list // Add the client to the connected list
@ -299,7 +318,6 @@ func (server *Server) updateCodecVersions() {
} }
log.Printf("CELT codec switch %v %v (PreferAlpha %v)", server.AlphaCodec, server.BetaCodec, server.PreferAlphaCodec) log.Printf("CELT codec switch %v %v (PreferAlpha %v)", server.AlphaCodec, server.BetaCodec, server.PreferAlphaCodec)
return return
} }
@ -467,8 +485,7 @@ func (server *Server) ListenUDP() {
// //
// If we don't find any matches, we look in the 'hclients', // If we don't find any matches, we look in the 'hclients',
// which maps a host address to a slice of clients. // which maps a host address to a slice of clients.
server.hmutex.RLock() server.hmutex.Lock()
defer server.hmutex.RUnlock()
client, ok := server.hpclients[udpaddr.String()] client, ok := server.hpclients[udpaddr.String()]
if ok { if ok {
err = client.crypt.Decrypt(buf[0:nread], plain[0:]) err = client.crypt.Decrypt(buf[0:nread], plain[0:])
@ -478,7 +495,6 @@ func (server *Server) ListenUDP() {
match = client match = client
} else { } else {
host := udpaddr.IP.String() host := udpaddr.IP.String()
server.hmutex.RLock()
hostclients := server.hclients[host] hostclients := server.hclients[host]
for _, client := range hostclients { for _, client := range hostclients {
err = client.crypt.Decrypt(buf[0:nread], plain[0:]) err = client.crypt.Decrypt(buf[0:nread], plain[0:])
@ -489,6 +505,7 @@ func (server *Server) ListenUDP() {
} }
} }
} }
server.hmutex.Unlock()
// No client found. // No client found.
if match == nil { if match == nil {