forked from External/grumble
Use a map instead of slices for storing client pointers in the Server struct.
This commit is contained in:
parent
122b6af163
commit
a57908b487
3 changed files with 99 additions and 101 deletions
45
client.go
45
client.go
|
|
@ -49,6 +49,12 @@ func (client *Client) Panic(reason string) {
|
||||||
// so it can perform a proper disconnect.
|
// so it can perform a proper disconnect.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (client *Client) Disconnect() {
|
||||||
|
client.disconnected = true
|
||||||
|
close(client.udprecv)
|
||||||
|
close(client.msgchan)
|
||||||
|
}
|
||||||
|
|
||||||
// Read a protobuf message from a client
|
// Read a protobuf message from a client
|
||||||
func (client *Client) readProtoMessage() (msg *Message, err os.Error) {
|
func (client *Client) readProtoMessage() (msg *Message, err os.Error) {
|
||||||
var length uint32
|
var length uint32
|
||||||
|
|
@ -101,8 +107,13 @@ func (c *Client) sendProtoMessage(kind uint16, msg interface{}) (err os.Error) {
|
||||||
|
|
||||||
// UDP receiver.
|
// UDP receiver.
|
||||||
func (client *Client) udpreceiver() {
|
func (client *Client) udpreceiver() {
|
||||||
for {
|
for buf := range client.udprecv {
|
||||||
buf := <-client.udprecv
|
|
||||||
|
// Channel close.
|
||||||
|
if len(buf) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
kind := (buf[0] >> 5) & 0x07;
|
kind := (buf[0] >> 5) & 0x07;
|
||||||
|
|
||||||
switch kind {
|
switch kind {
|
||||||
|
|
@ -162,8 +173,11 @@ func (client *Client) sendUdp(msg *Message) {
|
||||||
// Sender Goroutine
|
// Sender Goroutine
|
||||||
//
|
//
|
||||||
func (client *Client) sender() {
|
func (client *Client) sender() {
|
||||||
for {
|
for msg := range client.msgchan {
|
||||||
msg := <-client.msgchan
|
// Check for channel close.
|
||||||
|
if len(msg.buf) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// First, we write out the message type as a big-endian uint16
|
// First, we write out the message type as a big-endian uint16
|
||||||
err := binary.Write(client.writer, binary.BigEndian, msg.kind)
|
err := binary.Write(client.writer, binary.BigEndian, msg.kind)
|
||||||
|
|
@ -198,12 +212,17 @@ func (client *Client) sender() {
|
||||||
// Receiver Goroutine
|
// Receiver Goroutine
|
||||||
func (client *Client) receiver() {
|
func (client *Client) receiver() {
|
||||||
for {
|
for {
|
||||||
|
|
||||||
// The version handshake is done. Forward this message to the synchronous request handler.
|
// The version handshake is done. Forward this message to the synchronous request handler.
|
||||||
if client.state == StateClientAuthenticated || client.state == StateClientSentVersion {
|
if client.state == StateClientAuthenticated || client.state == StateClientSentVersion {
|
||||||
// Try to read the next message in the pool
|
// Try to read the next message in the pool
|
||||||
msg, err := client.readProtoMessage()
|
msg, err := client.readProtoMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == os.EOF {
|
||||||
|
log.Printf("Client disconnected.")
|
||||||
|
client.Disconnect()
|
||||||
|
} else {
|
||||||
|
log.Printf("Client error.")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Special case UDPTunnel messages. They're high priority and shouldn't
|
// Special case UDPTunnel messages. They're high priority and shouldn't
|
||||||
|
|
@ -231,6 +250,12 @@ func (client *Client) receiver() {
|
||||||
} else if client.state == StateServerSentVersion {
|
} else if client.state == StateServerSentVersion {
|
||||||
msg, err := client.readProtoMessage()
|
msg, err := client.readProtoMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == os.EOF {
|
||||||
|
log.Printf("Client disconnected.")
|
||||||
|
client.Disconnect()
|
||||||
|
} else {
|
||||||
|
log.Printf("Client error.")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -268,18 +293,14 @@ func (client *Client) sendChannelList() {
|
||||||
// Send the userlist to a client.
|
// Send the userlist to a client.
|
||||||
func (client *Client) sendUserList() {
|
func (client *Client) sendUserList() {
|
||||||
server := client.server
|
server := client.server
|
||||||
|
for _, client := range server.clients {
|
||||||
server.cmutex.RLock()
|
err := client.sendProtoMessage(MessageUserState, &mumbleproto.UserState{
|
||||||
defer server.cmutex.RUnlock()
|
|
||||||
|
|
||||||
for _, user := range server.clients {
|
|
||||||
err := user.sendProtoMessage(MessageUserState, &mumbleproto.UserState{
|
|
||||||
Session: proto.Uint32(client.Session),
|
Session: proto.Uint32(client.Session),
|
||||||
Name: proto.String(client.Username),
|
Name: proto.String(client.Username),
|
||||||
ChannelId: proto.Uint32(0),
|
ChannelId: proto.Uint32(0),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("unable to send!")
|
log.Printf("Unable to send UserList")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ type Message struct {
|
||||||
// is ignored for UDP packets.
|
// is ignored for UDP packets.
|
||||||
kind uint16
|
kind uint16
|
||||||
|
|
||||||
// For UDP datagrams one of these fiels have to be filled out.
|
// For UDP datagrams one of these fields have to be filled out.
|
||||||
// If there is no connection established, address must be used.
|
// If there is no connection established, address must be used.
|
||||||
// If the datagram comes from an already-connected client, the
|
// If the datagram comes from an already-connected client, the
|
||||||
// client field should point to that client.
|
// client field should point to that client.
|
||||||
|
|
@ -143,7 +143,10 @@ func (server *Server) handleTextMessage(client *Client, msg *Message) {
|
||||||
|
|
||||||
users := []*Client{};
|
users := []*Client{};
|
||||||
for i := 0; i < len(txtmsg.Session); i++ {
|
for i := 0; i < len(txtmsg.Session); i++ {
|
||||||
user := server.getClient(txtmsg.Session[i])
|
user, ok := server.clients[txtmsg.Session[i]]
|
||||||
|
if !ok {
|
||||||
|
log.Panic("Could not look up client by session")
|
||||||
|
}
|
||||||
users = append(users, user)
|
users = append(users, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -170,4 +173,5 @@ func (server *Server) handleUserStatsMessage(client *Client, msg *Message) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Panic(err.String())
|
client.Panic(err.String())
|
||||||
}
|
}
|
||||||
|
log.Printf("UserStatsMessage")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
149
server.go
149
server.go
|
|
@ -43,18 +43,19 @@ type Server struct {
|
||||||
|
|
||||||
incoming chan *Message
|
incoming chan *Message
|
||||||
outgoing chan *Message
|
outgoing chan *Message
|
||||||
|
|
||||||
udpsend chan *Message
|
udpsend chan *Message
|
||||||
|
|
||||||
// Config-related
|
// Config-related
|
||||||
MaxUsers int
|
MaxUsers int
|
||||||
MaxBandwidth uint32
|
MaxBandwidth uint32
|
||||||
|
|
||||||
|
// Clients
|
||||||
session uint32
|
session uint32
|
||||||
|
clients map[uint32]*Client
|
||||||
|
|
||||||
// A list of all connected clients
|
hmutex *sync.RWMutex
|
||||||
cmutex *sync.RWMutex
|
hclients map[string][]*Client
|
||||||
clients []*Client
|
hpclients map[string]*Client
|
||||||
|
|
||||||
// Codec information
|
// Codec information
|
||||||
AlphaCodec int32
|
AlphaCodec int32
|
||||||
|
|
@ -81,8 +82,11 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
|
||||||
s.address = addr
|
s.address = addr
|
||||||
s.port = port
|
s.port = port
|
||||||
|
|
||||||
// Create the list of connected clients
|
s.clients = make(map[uint32]*Client)
|
||||||
s.cmutex = new(sync.RWMutex)
|
|
||||||
|
s.hmutex = new(sync.RWMutex)
|
||||||
|
s.hclients = make(map[string][]*Client)
|
||||||
|
s.hpclients = make(map[string]*Client)
|
||||||
|
|
||||||
s.outgoing = make(chan *Message)
|
s.outgoing = make(chan *Message)
|
||||||
s.incoming = make(chan *Message)
|
s.incoming = make(chan *Message)
|
||||||
|
|
@ -91,8 +95,6 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
|
||||||
s.MaxBandwidth = 300000
|
s.MaxBandwidth = 300000
|
||||||
s.MaxUsers = 10
|
s.MaxUsers = 10
|
||||||
|
|
||||||
// Allocate the root channel
|
|
||||||
|
|
||||||
s.root = &Channel{
|
s.root = &Channel{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Name: "Root",
|
Name: "Root",
|
||||||
|
|
@ -107,13 +109,13 @@ func NewServer(addr string, port int) (s *Server, err os.Error) {
|
||||||
// Called by the server to initiate a new client connection.
|
// Called by the server to initiate a new client connection.
|
||||||
func (server *Server) NewClient(conn net.Conn) (err os.Error) {
|
func (server *Server) NewClient(conn net.Conn) (err os.Error) {
|
||||||
client := new(Client)
|
client := new(Client)
|
||||||
|
addr := conn.RemoteAddr()
|
||||||
// Get the address of the connected client
|
if addr == nil {
|
||||||
if addr := conn.RemoteAddr(); addr != nil {
|
err = os.NewError("Unable to extract address for client.")
|
||||||
client.tcpaddr = addr.(*net.TCPAddr)
|
return
|
||||||
log.Printf("client connected: %s", client.tcpaddr.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client.tcpaddr = addr.(*net.TCPAddr)
|
||||||
client.server = server
|
client.server = server
|
||||||
client.conn = conn
|
client.conn = conn
|
||||||
client.reader = bufio.NewReader(client.conn)
|
client.reader = bufio.NewReader(client.conn)
|
||||||
|
|
@ -123,15 +125,6 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) {
|
||||||
client.msgchan = make(chan *Message)
|
client.msgchan = make(chan *Message)
|
||||||
client.udprecv = make(chan []byte)
|
client.udprecv = make(chan []byte)
|
||||||
|
|
||||||
// New client connection....
|
|
||||||
server.session += 1
|
|
||||||
client.Session = server.session
|
|
||||||
|
|
||||||
// Add it to the list of connected clients
|
|
||||||
server.cmutex.Lock()
|
|
||||||
server.clients = append(server.clients, client)
|
|
||||||
server.cmutex.Unlock()
|
|
||||||
|
|
||||||
go client.receiver()
|
go client.receiver()
|
||||||
go client.udpreceiver()
|
go client.udpreceiver()
|
||||||
go client.sender()
|
go client.sender()
|
||||||
|
|
@ -139,21 +132,9 @@ func (server *Server) NewClient(conn net.Conn) (err os.Error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup a client by it's session id. Optimize this by using a map.
|
// This is the synchronous handler goroutine.
|
||||||
func (server *Server) getClient(session uint32) (client *Client) {
|
// Important control channel messages are routed through this Goroutine
|
||||||
server.cmutex.RLock()
|
// to keep server state synchronized.
|
||||||
defer server.cmutex.RUnlock()
|
|
||||||
|
|
||||||
for _, user := range server.clients {
|
|
||||||
if user.Session == session {
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is the synchronous request handler for all incoming messages.
|
|
||||||
func (server *Server) handler() {
|
func (server *Server) handler() {
|
||||||
for {
|
for {
|
||||||
msg := <-server.incoming
|
msg := <-server.incoming
|
||||||
|
|
@ -167,6 +148,7 @@ func (server *Server) handler() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle a Authenticate protobuf message.
|
||||||
func (server *Server) handleAuthenticate(client *Client, msg *Message) {
|
func (server *Server) handleAuthenticate(client *Client, msg *Message) {
|
||||||
// Is this message not an authenticate message? If not, discard it...
|
// Is this message not an authenticate message? If not, discard it...
|
||||||
if msg.kind != MessageAuthenticate {
|
if msg.kind != MessageAuthenticate {
|
||||||
|
|
@ -201,8 +183,8 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send CryptState information to the client so it can establish an UDP connection
|
// Send CryptState information to the client so it can establish an UDP connection,
|
||||||
// (if it wishes)...
|
// if it wishes.
|
||||||
err = client.sendProtoMessage(MessageCryptSetup, &mumbleproto.CryptSetup{
|
err = client.sendProtoMessage(MessageCryptSetup, &mumbleproto.CryptSetup{
|
||||||
Key: client.crypt.RawKey[0:],
|
Key: client.crypt.RawKey[0:],
|
||||||
ClientNonce: client.crypt.DecryptIV[0:],
|
ClientNonce: client.crypt.DecryptIV[0:],
|
||||||
|
|
@ -219,7 +201,18 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
|
||||||
|
|
||||||
client.state = StateClientAuthenticated
|
client.state = StateClientAuthenticated
|
||||||
|
|
||||||
// Broadcast that we, the client, entered a channel...
|
// Add the client to the connected list
|
||||||
|
server.session += 1
|
||||||
|
client.Session = server.session
|
||||||
|
server.clients[client.Session] = client
|
||||||
|
|
||||||
|
// Add the client to the host slice for its host address.
|
||||||
|
host := client.tcpaddr.IP.String()
|
||||||
|
server.hmutex.Lock()
|
||||||
|
server.hclients[host] = append(server.hclients[host], client)
|
||||||
|
server.hmutex.Unlock()
|
||||||
|
|
||||||
|
// Broadcast the the user entered a channel
|
||||||
err = server.broadcastProtoMessage(MessageUserState, &mumbleproto.UserState{
|
err = server.broadcastProtoMessage(MessageUserState, &mumbleproto.UserState{
|
||||||
Session: proto.Uint32(client.Session),
|
Session: proto.Uint32(client.Session),
|
||||||
Name: proto.String(client.Username),
|
Name: proto.String(client.Username),
|
||||||
|
|
@ -258,9 +251,6 @@ func (server *Server) updateCodecVersions() {
|
||||||
var winner int32
|
var winner int32
|
||||||
var count int
|
var count int
|
||||||
|
|
||||||
server.cmutex.RLock()
|
|
||||||
defer server.cmutex.RUnlock()
|
|
||||||
|
|
||||||
for _, client := range server.clients {
|
for _, client := range server.clients {
|
||||||
for i := 0; i < len(client.codecs); i++ {
|
for i := 0; i < len(client.codecs); i++ {
|
||||||
codecusers[client.codecs[i]] += 1
|
codecusers[client.codecs[i]] += 1
|
||||||
|
|
@ -314,9 +304,6 @@ func (server *Server) updateCodecVersions() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) sendUserList(client *Client) {
|
func (server *Server) sendUserList(client *Client) {
|
||||||
server.cmutex.RLock()
|
|
||||||
defer server.cmutex.RUnlock()
|
|
||||||
|
|
||||||
for _, user := range server.clients {
|
for _, user := range server.clients {
|
||||||
if user.state != StateClientAuthenticated {
|
if user.state != StateClientAuthenticated {
|
||||||
continue
|
continue
|
||||||
|
|
@ -339,9 +326,6 @@ func (server *Server) sendUserList(client *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) broadcastProtoMessage(kind uint16, msg interface{}) (err os.Error) {
|
func (server *Server) broadcastProtoMessage(kind uint16, msg interface{}) (err os.Error) {
|
||||||
server.cmutex.RLock()
|
|
||||||
defer server.cmutex.RUnlock()
|
|
||||||
|
|
||||||
for _, client := range server.clients {
|
for _, client := range server.clients {
|
||||||
if client.state != StateClientAuthenticated {
|
if client.state != StateClientAuthenticated {
|
||||||
continue
|
continue
|
||||||
|
|
@ -476,55 +460,45 @@ func (server *Server) ListenUDP() {
|
||||||
} else {
|
} else {
|
||||||
var match *Client
|
var match *Client
|
||||||
plain := make([]byte, nread-4)
|
plain := make([]byte, nread-4)
|
||||||
decrypted := false
|
|
||||||
|
|
||||||
// First, check if any of our clients match the net.UDPAddr...
|
// Determine which client sent the the packet. First, we
|
||||||
server.cmutex.RLock()
|
// check the map 'hpclients' in the server struct. It maps
|
||||||
for _, client := range server.clients {
|
// a hort-post combination to a client.
|
||||||
if client.udpaddr.String() == udpaddr.String() {
|
//
|
||||||
match = client
|
// If we don't find any matches, we look in the 'hclients',
|
||||||
}
|
// which maps a host address to a slice of clients.
|
||||||
}
|
server.hmutex.RLock()
|
||||||
server.cmutex.RUnlock()
|
defer server.hmutex.RUnlock()
|
||||||
|
client, ok := server.hpclients[udpaddr.String()]
|
||||||
// No matching client found. We must try to decrypt...
|
if ok {
|
||||||
if match == nil {
|
|
||||||
server.cmutex.RLock()
|
|
||||||
for _, client := range server.clients {
|
|
||||||
// Try to decrypt.
|
|
||||||
err = client.crypt.Decrypt(buf[0:nread], plain[0:])
|
err = client.crypt.Decrypt(buf[0:nread], plain[0:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Decryption failed. Try another client...
|
log.Panicf("Unable to decrypt incoming packet for client %v (host-port matched)", client)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decryption succeeded.
|
|
||||||
decrypted = true
|
|
||||||
|
|
||||||
// If we were able to successfully decrpyt, add
|
|
||||||
// the UDPAddr to the Client struct.
|
|
||||||
log.Printf("Client UDP connection established.")
|
|
||||||
client.udpaddr = remote.(*net.UDPAddr)
|
|
||||||
match = client
|
match = client
|
||||||
|
} else {
|
||||||
break
|
host := udpaddr.IP.String()
|
||||||
|
server.hmutex.RLock()
|
||||||
|
hostclients := server.hclients[host]
|
||||||
|
for _, client := range hostclients {
|
||||||
|
err = client.crypt.Decrypt(buf[0:nread], plain[0:])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
match = client
|
||||||
|
}
|
||||||
}
|
}
|
||||||
server.cmutex.RUnlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// We were not able to find a client that could decrypt the incoming
|
// No client found.
|
||||||
// packet. Log it?
|
|
||||||
if match == nil {
|
if match == nil {
|
||||||
|
log.Printf("No match found for packet. Discarding...")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !decrypted {
|
if match.udpaddr == nil {
|
||||||
err = match.crypt.Decrypt(buf[0:nread], plain[0:])
|
match.udpaddr = udpaddr
|
||||||
if err != nil {
|
|
||||||
log.Printf("Unable to decrypt from client..")
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
match.udp = true
|
match.udp = true
|
||||||
match.udprecv <- plain
|
match.udprecv <- plain
|
||||||
}
|
}
|
||||||
|
|
@ -553,16 +527,15 @@ func (s *Server) ListenAndMurmur() {
|
||||||
// when we do get a new connection, we spawn
|
// when we do get a new connection, we spawn
|
||||||
// a new Go-routine to handle the client.
|
// a new Go-routine to handle the client.
|
||||||
for {
|
for {
|
||||||
|
|
||||||
// New client connected
|
// New client connected
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("unable to accept()")
|
log.Printf("Unable to accept() new client.")
|
||||||
}
|
}
|
||||||
|
|
||||||
tls, ok := conn.(*tls.Conn)
|
tls, ok := conn.(*tls.Conn)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Not tls :(")
|
log.Panic("Internal inconsistency error.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force the TLS handshake to get going. We'd like
|
// Force the TLS handshake to get going. We'd like
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue