Async authentication. Improved disconnect handling.

This commit is contained in:
Mikkel Krautz 2011-04-10 02:18:56 +02:00
parent be73cceb8b
commit 3014bf7e3d
2 changed files with 179 additions and 54 deletions

135
client.go
View file

@ -28,8 +28,9 @@ type Client struct {
state int state int
server *Server server *Server
msgchan chan *Message msgchan chan *Message
udprecv chan []byte udprecv chan []byte
doneSending chan bool
disconnected bool disconnected bool
@ -46,6 +47,13 @@ type Client struct {
// the user data store, so we have to keep track of it separately. // the user data store, so we have to keep track of it separately.
superUser bool superUser bool
// The clientReady channel signals the client's reciever routine that
// the client has been successfully authenticated and that it has been
// sent the necessary information to be a participant on the server.
// When this signal is received, the client has transitioned into the
// 'ready' state.
clientReady chan bool
// Version // Version
Version uint32 Version uint32
ClientName string ClientName string
@ -123,6 +131,23 @@ func (client *Client) disconnect(kicked bool) {
if !client.disconnected { if !client.disconnected {
client.disconnected = true client.disconnected = true
close(client.udprecv) close(client.udprecv)
// If the client paniced during authentication, before reaching
// the ready state, the receiver goroutine will be waiting for
// a signal telling it that the client is ready to receive 'real'
// messages from the server.
//
// In case of a premature disconnect, close the channel so the
// receiver routine can exit correctly.
if client.state == StateClientSentVersion || client.state == StateClientAuthenticated {
close(client.clientReady)
}
// Cleanly shut down the sender goroutine. This should be non-blocking
// since we're writing to a bufio.Writer.
// todo(mkrautz): Check whether that's the case? We do a flush, so maybe not.
client.msgchan <- nil
<-client.doneSending
close(client.msgchan) close(client.msgchan)
client.conn.Close() client.conn.Close()
@ -259,7 +284,8 @@ func (c *Client) sendPermissionDeniedFallback(kind string, version uint32, text
// UDP receiver. // UDP receiver.
func (client *Client) udpreceiver() { func (client *Client) udpreceiver() {
for buf := range client.udprecv { for buf := range client.udprecv {
// Channel close. // Received a zero-valued buffer. This means that the udprecv
// channel was closed, so exit cleanly.
if len(buf) == 0 { if len(buf) == 0 {
return return
} }
@ -328,42 +354,60 @@ func (client *Client) sendUdp(msg *Message) {
} }
} }
// Send a Message to the client. The Message in msg to the client's
// buffered writer and flushes it when done.
//
// This method should only be called from within the client's own
// sender goroutine, since it serializes access to the underlying
// buffered writer.
func (client *Client) sendMessage(msg *Message) os.Error {
// Write message kind
err := binary.Write(client.writer, binary.BigEndian, msg.kind)
if err != nil {
return err
}
// Message length
err = binary.Write(client.writer, binary.BigEndian, uint32(len(msg.buf)))
if err != nil {
return err
}
// Message buffer itself
_, err = client.writer.Write(msg.buf)
if err != nil {
return err
}
// Flush it, no need to keep it in the buffer for any longer.
err = client.writer.Flush()
if err != nil {
return err
}
return nil
}
// Sender Goroutine. The sender goroutine will initiate shutdown
// if it receives a nil Message.
// //
// Sender Goroutine // On shutdown, it will send a true boolean value on the client's
// // doneSending channel. This allows the client to send all the messages
// that remain in it's buffer when the server has to force a disconnect.
func (client *Client) sender() { func (client *Client) sender() {
defer func() {
client.doneSending <- true
}()
for msg := range client.msgchan { for msg := range client.msgchan {
// Check for channel close. if msg == nil {
if len(msg.buf) == 0 {
return return
} }
// First, we write out the message type as a big-endian uint16 err := client.sendMessage(msg)
err := binary.Write(client.writer, binary.BigEndian, msg.kind)
if err != nil { if err != nil {
client.Panic("Unable to write message type to client") // fixme(mkrautz): This is a deadlock waiting to happen.
return client.Panic("Unable to send message to client")
}
// Then the length of the protobuf message
err = binary.Write(client.writer, binary.BigEndian, uint32(len(msg.buf)))
if err != nil {
client.Panic("Unable to write message length to client")
return
}
// At last, write the buffer itself
_, err = client.writer.Write(msg.buf)
if err != nil {
client.Panic("Unable to write message content to client")
return
}
// Flush the write buffer
err = client.writer.Flush()
if err != nil {
client.Panic("Unable to flush client write buffer")
return return
} }
} }
@ -372,8 +416,9 @@ func (client *Client) sender() {
// Receiver Goroutine // Receiver Goroutine
func (client *Client) receiver() { func (client *Client) receiver() {
for { for {
// The version handshake is done. Forward this message to the synchronous request handler. // The version handshake is done, the client has been authenticated and it has received
if client.state == StateClientAuthenticated || client.state == StateClientSentVersion { // all necessary information regarding the server. Now we're ready to roll!
if client.state == StateClientReady {
// Try to read the next message in the pool // Try to read the next message in the pool
msg, err := client.readProtoMessage() msg, err := client.readProtoMessage()
if err != nil { if err != nil {
@ -395,6 +440,30 @@ func (client *Client) receiver() {
} }
} }
// The client has responded to our version query. It will try to authenticate.
if client.state == StateClientSentVersion {
// Try to read the next message in the pool
msg, err := client.readProtoMessage()
if err != nil {
client.Panic(err.String())
return
}
client.clientReady = make(chan bool)
go client.server.handleAuthenticate(client, msg)
<-client.clientReady
// It's possible that the client has disconnected in the meantime.
// In that case, step out of the receiver, since there's nothing left
// to receive.
if client.disconnected {
return
}
close(client.clientReady)
client.clientReady = nil
}
// The client has just connected. Before it sends its authentication // The client has just connected. Before it sends its authentication
// information we must send it our version information so it knows // information we must send it our version information so it knows
// what version of the protocol it should speak. // what version of the protocol it should speak.

View file

@ -32,6 +32,7 @@ const (
StateServerSentVersion StateServerSentVersion
StateClientSentVersion StateClientSentVersion
StateClientAuthenticated StateClientAuthenticated
StateClientReady
StateClientDead StateClientDead
) )
@ -46,6 +47,10 @@ type Server struct {
incoming chan *Message incoming chan *Message
udpsend chan *Message udpsend chan *Message
voicebroadcast chan *VoiceBroadcast voicebroadcast chan *VoiceBroadcast
usercheck chan *userCheck
// Signals to the server that a client has been successfully
// authenticated.
clientAuthenticated chan *Client
// Config-related // Config-related
MaxUsers int MaxUsers int
@ -80,6 +85,12 @@ type Server struct {
aclcache ACLCache aclcache ACLCache
} }
type userCheck struct {
done chan bool
UserId int
Addr string
}
// Allocate a new Murmur instance // Allocate a new Murmur instance
func NewServer(id int64, addr string, port int) (s *Server, err os.Error) { func NewServer(id int64, addr string, port int) (s *Server, err os.Error) {
s = new(Server) s = new(Server)
@ -99,6 +110,8 @@ func NewServer(id int64, addr string, port int) (s *Server, err os.Error) {
s.incoming = make(chan *Message) s.incoming = make(chan *Message)
s.udpsend = make(chan *Message) s.udpsend = make(chan *Message)
s.voicebroadcast = make(chan *VoiceBroadcast) s.voicebroadcast = make(chan *VoiceBroadcast)
s.usercheck = make(chan *userCheck)
s.clientAuthenticated = make(chan *Client)
s.MaxBandwidth = 300000 s.MaxBandwidth = 300000
s.MaxUsers = 10 s.MaxUsers = 10
@ -169,6 +182,8 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) {
go client.receiver() go client.receiver()
go client.udpreceiver() go client.udpreceiver()
client.doneSending = make(chan bool)
go client.sender() go client.sender()
return return
@ -203,7 +218,7 @@ func (server *Server) RemoveClient(client *Client, kicked bool) {
// If the user was not kicked, broadcast a UserRemove message. // If the user was not kicked, broadcast a UserRemove message.
// If the user is disconnect via a kick, the UserRemove message has already been sent // If the user is disconnect via a kick, the UserRemove message has already been sent
// at this point. // at this point.
if !kicked { if !kicked && client.state > StateClientAuthenticated {
err := server.broadcastProtoMessage(MessageUserRemove, &mumbleproto.UserRemove{ err := server.broadcastProtoMessage(MessageUserRemove, &mumbleproto.UserRemove{
Session: proto.Uint32(client.Session), Session: proto.Uint32(client.Session),
}) })
@ -269,11 +284,7 @@ func (server *Server) handler() {
// Control channel messages // Control channel messages
case msg := <-server.incoming: case msg := <-server.incoming:
client := msg.client client := msg.client
if client.state == StateClientAuthenticated { server.handleIncomingMessage(client, msg)
server.handleIncomingMessage(client, msg)
} else if client.state == StateClientSentVersion {
server.handleAuthenticate(client, msg)
}
// Voice broadcast // Voice broadcast
case vb := <-server.voicebroadcast: case vb := <-server.voicebroadcast:
log.Printf("VoiceBroadcast!") log.Printf("VoiceBroadcast!")
@ -288,11 +299,34 @@ func (server *Server) handler() {
} }
} }
} }
// Finish client authentication. Send post-authentication
// server info.
case client := <-server.clientAuthenticated:
server.finishAuthenticate(client)
// User checking
case checker := <-server.usercheck:
found := false
for _, client := range server.clients {
if client.UserId() == checker.UserId {
checker.Addr = client.tcpaddr.String()
checker.done <- true
found = true
break
}
}
if !found {
checker.done <- false
}
} }
} }
} }
// Handle a Authenticate protobuf message. // Handle an Authenticate protobuf message. This is handled in a separate
// goroutine to allow for remote authenticators that are slow to respond.
//
// Once a user has been authenticated, it will ping the server's handler
// routine, which will call the finishAuthenticate method on Server which
// will send the channel tree, user list, etc. to the client.
func (server *Server) handleAuthenticate(client *Client, msg *Message) { func (server *Server) handleAuthenticate(client *Client, msg *Message) {
// Is this message not an authenticate message? If not, discard it... // Is this message not an authenticate message? If not, discard it...
if msg.kind != MessageAuthenticate { if msg.kind != MessageAuthenticate {
@ -309,7 +343,7 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
// Did we get a username? // Did we get a username?
if auth.Username == nil { if auth.Username == nil {
client.Panic("No username in auth message...") client.RejectAuth("InvalidUsername", "Please specify a username to log in")
return return
} }
@ -318,7 +352,7 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
// Extract certhash // Extract certhash
tlsconn, ok := client.conn.(*tls.Conn) tlsconn, ok := client.conn.(*tls.Conn)
if !ok { if !ok {
client.Panic("Type assertion failed") client.Panic("Invalid connection")
return return
} }
state := tlsconn.ConnectionState() state := tlsconn.ConnectionState()
@ -363,7 +397,20 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
// Found a user for this guy // Found a user for this guy
if client.user != nil { if client.user != nil {
log.Printf("Client authenticated as %v", client.user.Name) // Ask the server whether someone's already connecting using that user.
// This is a request to the Server's synchronous handler routine (the
// only routine that is guaranteed correct access to the internal client
// data).
checker := &userCheck{make(chan bool), int(client.user.Id), ""}
server.usercheck <- checker
foundUser := <-checker.done
if foundUser {
// todo(mkrautz): Murmur allows reconnects from same IP. That's pretty useful.
client.RejectAuth("UsernameInUse", "Someone else is already connected as this user")
return
} else {
log.Printf("Client authenticated as %v", client.user.Name)
}
} }
} }
@ -390,20 +437,28 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
client.Panic(err.String()) client.Panic(err.String())
} }
// Add the client to the connected list
server.session += 1
client.Session = server.session
server.clients[client.Session] = client
// Add codecs // Add codecs
client.codecs = auth.CeltVersions client.codecs = auth.CeltVersions
if len(client.codecs) == 0 { if len(client.codecs) == 0 {
log.Printf("Client %i connected without CELT codecs.", client.Session) log.Printf("Client %i connected without CELT codecs.", client.Session)
} }
client.state = StateClientAuthenticated
server.clientAuthenticated <- client
}
func (server *Server) finishAuthenticate(client *Client) {
// Add the client to the connected list
client.Session = server.session
server.clients[client.Session] = client
log.Printf("Assigned client session=%v", client.Session)
server.session += 1
// First, check whether we need to tell the other connected
// clients to switch to a codec so the new guy can actually speak.
server.updateCodecVersions() server.updateCodecVersions()
client.sendChannelList() client.sendChannelList()
client.state = StateClientAuthenticated
// Add the client to the host slice for its host address. // Add the client to the host slice for its host address.
host := client.tcpaddr.IP.String() host := client.tcpaddr.IP.String()
@ -441,12 +496,12 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
perm.ClearCacheBit() perm.ClearCacheBit()
sync.Permissions = proto.Uint64(uint64(perm)) sync.Permissions = proto.Uint64(uint64(perm))
} }
if err = client.sendProtoMessage(MessageServerSync, sync); err != nil { if err := client.sendProtoMessage(MessageServerSync, sync); err != nil {
client.Panic(err.String()) client.Panic(err.String())
return return
} }
err = client.sendProtoMessage(MessageServerConfig, &mumbleproto.ServerConfig{ err := client.sendProtoMessage(MessageServerConfig, &mumbleproto.ServerConfig{
AllowHtml: proto.Bool(true), AllowHtml: proto.Bool(true),
MessageLength: proto.Uint32(1000), MessageLength: proto.Uint32(1000),
ImageMessageLength: proto.Uint32(1000), ImageMessageLength: proto.Uint32(1000),
@ -456,7 +511,8 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
return return
} }
client.state = StateClientAuthenticated client.state = StateClientReady
client.clientReady <- true
} }
func (server *Server) updateCodecVersions() { func (server *Server) updateCodecVersions() {
@ -519,7 +575,7 @@ func (server *Server) updateCodecVersions() {
func (server *Server) sendUserList(client *Client) { func (server *Server) sendUserList(client *Client) {
for _, user := range server.clients { for _, user := range server.clients {
if user.state != StateClientAuthenticated { if user.state != StateClientReady {
continue continue
} }
if user == client { if user == client {
@ -568,7 +624,7 @@ func (server *Server) broadcastProtoMessageWithPredicate(kind uint16, msg interf
if !clientcheck(client) { if !clientcheck(client) {
continue continue
} }
if client.state != StateClientAuthenticated { if client.state < StateClientAuthenticated {
continue continue
} }
err := client.sendProtoMessage(kind, msg) err := client.sendProtoMessage(kind, msg)