1
0
Fork 0
forked from External/grumble

Use a map instead of slices for storing client pointers in the Server struct.

This commit is contained in:
Mikkel Krautz 2010-11-20 00:55:51 +01:00
parent 122b6af163
commit a57908b487
3 changed files with 99 additions and 101 deletions

View file

@ -49,6 +49,12 @@ func (client *Client) Panic(reason string) {
// so it can perform a proper disconnect. // so it can perform a proper disconnect.
} }
func (client *Client) Disconnect() {
client.disconnected = true
close(client.udprecv)
close(client.msgchan)
}
// Read a protobuf message from a client // Read a protobuf message from a client
func (client *Client) readProtoMessage() (msg *Message, err os.Error) { func (client *Client) readProtoMessage() (msg *Message, err os.Error) {
var length uint32 var length uint32
@ -101,8 +107,13 @@ func (c *Client) sendProtoMessage(kind uint16, msg interface{}) (err os.Error) {
// UDP receiver. // UDP receiver.
func (client *Client) udpreceiver() { func (client *Client) udpreceiver() {
for { for buf := range client.udprecv {
buf := <-client.udprecv
// Channel close.
if len(buf) == 0 {
return
}
kind := (buf[0] >> 5) & 0x07; kind := (buf[0] >> 5) & 0x07;
switch kind { switch kind {
@ -162,8 +173,11 @@ func (client *Client) sendUdp(msg *Message) {
// Sender Goroutine // Sender Goroutine
// //
func (client *Client) sender() { func (client *Client) sender() {
for { for msg := range client.msgchan {
msg := <-client.msgchan // Check for channel close.
if len(msg.buf) == 0 {
return
}
// First, we write out the message type as a big-endian uint16 // First, we write out the message type as a big-endian uint16
err := binary.Write(client.writer, binary.BigEndian, msg.kind) err := binary.Write(client.writer, binary.BigEndian, msg.kind)
@ -198,12 +212,17 @@ func (client *Client) sender() {
// Receiver Goroutine // Receiver Goroutine
func (client *Client) receiver() { func (client *Client) receiver() {
for { for {
// The version handshake is done. Forward this message to the synchronous request handler. // The version handshake is done. Forward this message to the synchronous request handler.
if client.state == StateClientAuthenticated || client.state == StateClientSentVersion { if client.state == StateClientAuthenticated || client.state == StateClientSentVersion {
// Try to read the next message in the pool // Try to read the next message in the pool
msg, err := client.readProtoMessage() msg, err := client.readProtoMessage()
if err != nil { if err != nil {
if err == os.EOF {
log.Printf("Client disconnected.")
client.Disconnect()
} else {
log.Printf("Client error.")
}
return return
} }
// Special case UDPTunnel messages. They're high priority and shouldn't // Special case UDPTunnel messages. They're high priority and shouldn't
@ -231,6 +250,12 @@ func (client *Client) receiver() {
} else if client.state == StateServerSentVersion { } else if client.state == StateServerSentVersion {
msg, err := client.readProtoMessage() msg, err := client.readProtoMessage()
if err != nil { if err != nil {
if err == os.EOF {
log.Printf("Client disconnected.")
client.Disconnect()
} else {
log.Printf("Client error.")
}
return return
} }
@ -268,18 +293,14 @@ func (client *Client) sendChannelList() {
// Send the userlist to a client. // Send the userlist to a client.
func (client *Client) sendUserList() { func (client *Client) sendUserList() {
server := client.server server := client.server
for _, client := range server.clients {
server.cmutex.RLock() err := client.sendProtoMessage(MessageUserState, &mumbleproto.UserState{
defer server.cmutex.RUnlock()
for _, user := range server.clients {
err := user.sendProtoMessage(MessageUserState, &mumbleproto.UserState{
Session: proto.Uint32(client.Session), Session: proto.Uint32(client.Session),
Name: proto.String(client.Username), Name: proto.String(client.Username),
ChannelId: proto.Uint32(0), ChannelId: proto.Uint32(0),
}) })
if err != nil { if err != nil {
log.Printf("unable to send!") log.Printf("Unable to send UserList")
continue continue
} }
} }

View file

@ -57,7 +57,7 @@ type Message struct {
// is ignored for UDP packets. // is ignored for UDP packets.
kind uint16 kind uint16
// For UDP datagrams one of these fiels have to be filled out. // For UDP datagrams one of these fields have to be filled out.
// If there is no connection established, address must be used. // If there is no connection established, address must be used.
// If the datagram comes from an already-connected client, the // If the datagram comes from an already-connected client, the
// client field should point to that client. // client field should point to that client.
@ -143,7 +143,10 @@ func (server *Server) handleTextMessage(client *Client, msg *Message) {
users := []*Client{}; users := []*Client{};
for i := 0; i < len(txtmsg.Session); i++ { for i := 0; i < len(txtmsg.Session); i++ {
user := server.getClient(txtmsg.Session[i]) user, ok := server.clients[txtmsg.Session[i]]
if !ok {
log.Panic("Could not look up client by session")
}
users = append(users, user) users = append(users, user)
} }
@ -170,4 +173,5 @@ func (server *Server) handleUserStatsMessage(client *Client, msg *Message) {
if err != nil { if err != nil {
client.Panic(err.String()) client.Panic(err.String())
} }
log.Printf("UserStatsMessage")
} }

149
server.go
View file

@ -43,18 +43,19 @@ type Server struct {
incoming chan *Message incoming chan *Message
outgoing chan *Message outgoing chan *Message
udpsend chan *Message udpsend chan *Message
// Config-related // Config-related
MaxUsers int MaxUsers int
MaxBandwidth uint32 MaxBandwidth uint32
// Clients
session uint32 session uint32
clients map[uint32]*Client
// A list of all connected clients hmutex *sync.RWMutex
cmutex *sync.RWMutex hclients map[string][]*Client
clients []*Client hpclients map[string]*Client
// Codec information // Codec information
AlphaCodec int32 AlphaCodec int32
@ -81,8 +82,11 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
s.address = addr s.address = addr
s.port = port s.port = port
// Create the list of connected clients s.clients = make(map[uint32]*Client)
s.cmutex = new(sync.RWMutex)
s.hmutex = new(sync.RWMutex)
s.hclients = make(map[string][]*Client)
s.hpclients = make(map[string]*Client)
s.outgoing = make(chan *Message) s.outgoing = make(chan *Message)
s.incoming = make(chan *Message) s.incoming = make(chan *Message)
@ -91,8 +95,6 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
s.MaxBandwidth = 300000 s.MaxBandwidth = 300000
s.MaxUsers = 10 s.MaxUsers = 10
// Allocate the root channel
s.root = &Channel{ s.root = &Channel{
Id: 0, Id: 0,
Name: "Root", Name: "Root",
@ -107,13 +109,13 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
// Called by the server to initiate a new client connection. // Called by the server to initiate a new client connection.
func (server *Server) NewClient(conn net.Conn) (err os.Error) { func (server *Server) NewClient(conn net.Conn) (err os.Error) {
client := new(Client) client := new(Client)
addr := conn.RemoteAddr()
// Get the address of the connected client if addr == nil {
if addr := conn.RemoteAddr(); addr != nil { err = os.NewError("Unable to extract address for client.")
client.tcpaddr = addr.(*net.TCPAddr) return
log.Printf("client connected: %s", client.tcpaddr.String())
} }
client.tcpaddr = addr.(*net.TCPAddr)
client.server = server client.server = server
client.conn = conn client.conn = conn
client.reader = bufio.NewReader(client.conn) client.reader = bufio.NewReader(client.conn)
@ -123,15 +125,6 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) {
client.msgchan = make(chan *Message) client.msgchan = make(chan *Message)
client.udprecv = make(chan []byte) client.udprecv = make(chan []byte)
// New client connection....
server.session += 1
client.Session = server.session
// Add it to the list of connected clients
server.cmutex.Lock()
server.clients = append(server.clients, client)
server.cmutex.Unlock()
go client.receiver() go client.receiver()
go client.udpreceiver() go client.udpreceiver()
go client.sender() go client.sender()
@ -139,21 +132,9 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) {
return return
} }
// Lookup a client by it's session id. Optimize this by using a map. // This is the synchronous handler goroutine.
func (server *Server) getClient(session uint32) (client *Client) { // Important control channel messages are routed through this Goroutine
server.cmutex.RLock() // to keep server state synchronized.
defer server.cmutex.RUnlock()
for _, user := range server.clients {
if user.Session == session {
return user
}
}
return nil
}
// This is the synchronous request handler for all incoming messages.
func (server *Server) handler() { func (server *Server) handler() {
for { for {
msg := <-server.incoming msg := <-server.incoming
@ -167,6 +148,7 @@ func (server *Server) handler() {
} }
} }
// Handle a Authenticate protobuf message.
func (server *Server) handleAuthenticate(client *Client, msg *Message) { func (server *Server) handleAuthenticate(client *Client, msg *Message) {
// Is this message not an authenticate message? If not, discard it... // Is this message not an authenticate message? If not, discard it...
if msg.kind != MessageAuthenticate { if msg.kind != MessageAuthenticate {
@ -201,8 +183,8 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
return return
} }
// Send CryptState information to the client so it can establish an UDP connection // Send CryptState information to the client so it can establish an UDP connection,
// (if it wishes)... // if it wishes.
err = client.sendProtoMessage(MessageCryptSetup, &mumbleproto.CryptSetup{ err = client.sendProtoMessage(MessageCryptSetup, &mumbleproto.CryptSetup{
Key: client.crypt.RawKey[0:], Key: client.crypt.RawKey[0:],
ClientNonce: client.crypt.DecryptIV[0:], ClientNonce: client.crypt.DecryptIV[0:],
@ -219,7 +201,18 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
client.state = StateClientAuthenticated client.state = StateClientAuthenticated
// Broadcast that we, the client, entered a channel... // Add the client to the connected list
server.session += 1
client.Session = server.session
server.clients[client.Session] = client
// Add the client to the host slice for its host address.
host := client.tcpaddr.IP.String()
server.hmutex.Lock()
server.hclients[host] = append(server.hclients[host], client)
server.hmutex.Unlock()
// Broadcast the the user entered a channel
err = server.broadcastProtoMessage(MessageUserState, &mumbleproto.UserState{ err = server.broadcastProtoMessage(MessageUserState, &mumbleproto.UserState{
Session: proto.Uint32(client.Session), Session: proto.Uint32(client.Session),
Name: proto.String(client.Username), Name: proto.String(client.Username),
@ -258,9 +251,6 @@ func (server *Server) updateCodecVersions() {
var winner int32 var winner int32
var count int var count int
server.cmutex.RLock()
defer server.cmutex.RUnlock()
for _, client := range server.clients { for _, client := range server.clients {
for i := 0; i < len(client.codecs); i++ { for i := 0; i < len(client.codecs); i++ {
codecusers[client.codecs[i]] += 1 codecusers[client.codecs[i]] += 1
@ -314,9 +304,6 @@ func (server *Server) updateCodecVersions() {
} }
func (server *Server) sendUserList(client *Client) { func (server *Server) sendUserList(client *Client) {
server.cmutex.RLock()
defer server.cmutex.RUnlock()
for _, user := range server.clients { for _, user := range server.clients {
if user.state != StateClientAuthenticated { if user.state != StateClientAuthenticated {
continue continue
@ -339,9 +326,6 @@ func (server *Server) sendUserList(client *Client) {
} }
func (server *Server) broadcastProtoMessage(kind uint16, msg interface{}) (err os.Error) { func (server *Server) broadcastProtoMessage(kind uint16, msg interface{}) (err os.Error) {
server.cmutex.RLock()
defer server.cmutex.RUnlock()
for _, client := range server.clients { for _, client := range server.clients {
if client.state != StateClientAuthenticated { if client.state != StateClientAuthenticated {
continue continue
@ -476,55 +460,45 @@ func (server *Server) ListenUDP() {
} else { } else {
var match *Client var match *Client
plain := make([]byte, nread-4) plain := make([]byte, nread-4)
decrypted := false
// First, check if any of our clients match the net.UDPAddr... // Determine which client sent the the packet. First, we
server.cmutex.RLock() // check the map 'hpclients' in the server struct. It maps
for _, client := range server.clients { // a hort-post combination to a client.
if client.udpaddr.String() == udpaddr.String() { //
match = client // If we don't find any matches, we look in the 'hclients',
} // which maps a host address to a slice of clients.
} server.hmutex.RLock()
server.cmutex.RUnlock() defer server.hmutex.RUnlock()
client, ok := server.hpclients[udpaddr.String()]
// No matching client found. We must try to decrypt... if ok {
if match == nil {
server.cmutex.RLock()
for _, client := range server.clients {
// Try to decrypt.
err = client.crypt.Decrypt(buf[0:nread], plain[0:]) err = client.crypt.Decrypt(buf[0:nread], plain[0:])
if err != nil { if err != nil {
// Decryption failed. Try another client... log.Panicf("Unable to decrypt incoming packet for client %v (host-port matched)", client)
continue
} }
// Decryption succeeded.
decrypted = true
// If we were able to successfully decrpyt, add
// the UDPAddr to the Client struct.
log.Printf("Client UDP connection established.")
client.udpaddr = remote.(*net.UDPAddr)
match = client match = client
} else {
break host := udpaddr.IP.String()
server.hmutex.RLock()
hostclients := server.hclients[host]
for _, client := range hostclients {
err = client.crypt.Decrypt(buf[0:nread], plain[0:])
if err != nil {
continue
} else {
match = client
}
} }
server.cmutex.RUnlock()
} }
// We were not able to find a client that could decrypt the incoming // No client found.
// packet. Log it?
if match == nil { if match == nil {
log.Printf("No match found for packet. Discarding...")
continue continue
} }
if !decrypted { if match.udpaddr == nil {
err = match.crypt.Decrypt(buf[0:nread], plain[0:]) match.udpaddr = udpaddr
if err != nil {
log.Printf("Unable to decrypt from client..")
} }
}
match.udp = true match.udp = true
match.udprecv <- plain match.udprecv <- plain
} }
@ -553,16 +527,15 @@ func (s *Server) ListenAndMurmur() {
// when we do get a new connection, we spawn // when we do get a new connection, we spawn
// a new Go-routine to handle the client. // a new Go-routine to handle the client.
for { for {
// New client connected // New client connected
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
log.Printf("unable to accept()") log.Printf("Unable to accept() new client.")
} }
tls, ok := conn.(*tls.Conn) tls, ok := conn.(*tls.Conn)
if !ok { if !ok {
log.Printf("Not tls :(") log.Panic("Internal inconsistency error.")
} }
// Force the TLS handshake to get going. We'd like // Force the TLS handshake to get going. We'd like