Get rid of Client's sender goroutine.

This commit is contained in:
Mikkel Krautz 2011-11-11 21:08:32 +01:00
parent fa3770bffb
commit 875cc89b9e
4 changed files with 62 additions and 111 deletions

128
client.go
View file

@ -6,6 +6,7 @@ package main
import ( import (
"bufio" "bufio"
"bytes"
"encoding/binary" "encoding/binary"
"goprotobuf.googlecode.com/hg/proto" "goprotobuf.googlecode.com/hg/proto"
"grumble/blobstore" "grumble/blobstore"
@ -29,13 +30,10 @@ type Client struct {
udpaddr *net.UDPAddr udpaddr *net.UDPAddr
conn net.Conn conn net.Conn
reader *bufio.Reader reader *bufio.Reader
writer *bufio.Writer
state int state int
server *Server server *Server
msgchan chan *Message
udprecv chan []byte udprecv chan []byte
doneSending chan bool
disconnected bool disconnected bool
@ -157,13 +155,6 @@ func (client *Client) disconnect(kicked bool) {
close(client.clientReady) close(client.clientReady)
} }
// Cleanly shut down the sender goroutine. This should be non-blocking
// since we're writing to a bufio.Writer.
// todo(mkrautz): Check whether that's the case? We do a flush, so maybe not.
client.msgchan <- nil
<-client.doneSending
close(client.msgchan)
client.Printf("Disconnected") client.Printf("Disconnected")
client.conn.Close() client.conn.Close()
} }
@ -186,7 +177,7 @@ func (client *Client) RejectAuth(rejectType mumbleproto.Reject_RejectType, reaso
reasonString = proto.String(reason) reasonString = proto.String(reason)
} }
client.sendProtoMessage(&mumbleproto.Reject{ client.sendMessage(&mumbleproto.Reject{
Type: mumbleproto.NewReject_RejectType(rejectType), Type: mumbleproto.NewReject_RejectType(rejectType),
Reason: reasonString, Reason: reasonString,
}) })
@ -228,21 +219,6 @@ func (client *Client) readProtoMessage() (msg *Message, err error) {
return return
} }
// Send a protobuf-encoded message
func (c *Client) sendProtoMessage(msg interface{}) (err error) {
d, err := proto.Marshal(msg)
if err != nil {
return
}
c.msgchan <- &Message{
buf: d,
kind: mumbleproto.MessageType(msg),
}
return
}
// Send permission denied by type // Send permission denied by type
func (c *Client) sendPermissionDeniedType(denyType mumbleproto.PermissionDenied_DenyType) { func (c *Client) sendPermissionDeniedType(denyType mumbleproto.PermissionDenied_DenyType) {
c.sendPermissionDeniedTypeUser(denyType, nil) c.sendPermissionDeniedTypeUser(denyType, nil)
@ -256,31 +232,25 @@ func (c *Client) sendPermissionDeniedTypeUser(denyType mumbleproto.PermissionDen
if user != nil { if user != nil {
pd.Session = proto.Uint32(uint32(user.Session)) pd.Session = proto.Uint32(uint32(user.Session))
} }
d, err := proto.Marshal(pd) err := c.sendMessage(pd)
if err != nil { if err != nil {
c.Panicf("%v", err) c.Panicf("%v", err.Error())
return return
} }
c.msgchan <- &Message{
buf: d,
kind: mumbleproto.MessagePermissionDenied,
}
} }
// Send permission denied by who, what, where // Send permission denied by who, what, where
func (c *Client) sendPermissionDenied(who *Client, where *Channel, what Permission) { func (c *Client) sendPermissionDenied(who *Client, where *Channel, what Permission) {
d, err := proto.Marshal(&mumbleproto.PermissionDenied{ pd := &mumbleproto.PermissionDenied{
Permission: proto.Uint32(uint32(what)), Permission: proto.Uint32(uint32(what)),
ChannelId: proto.Uint32(uint32(where.Id)), ChannelId: proto.Uint32(uint32(where.Id)),
Session: proto.Uint32(who.Session), Session: proto.Uint32(who.Session),
Type: mumbleproto.NewPermissionDenied_DenyType(mumbleproto.PermissionDenied_Permission), Type: mumbleproto.NewPermissionDenied_DenyType(mumbleproto.PermissionDenied_Permission),
})
if err != nil {
c.Panicf(err.Error())
} }
c.msgchan <- &Message{ err := c.sendMessage(pd)
buf: d, if err != nil {
kind: mumbleproto.MessagePermissionDenied, c.Panicf("%v", err.Error())
return
} }
} }
@ -357,9 +327,7 @@ func (client *Client) sendUdp(msg *Message) {
client.Printf("Sent UDP!") client.Printf("Sent UDP!")
client.server.udpsend <- msg client.server.udpsend <- msg
} else { } else {
client.Printf("Sent TCP!") client.sendMessage(msg.buf)
msg.kind = mumbleproto.MessageUDPTunnel
client.msgchan <- msg
} }
} }
@ -369,27 +337,38 @@ func (client *Client) sendUdp(msg *Message) {
// This method should only be called from within the client's own // This method should only be called from within the client's own
// sender goroutine, since it serializes access to the underlying // sender goroutine, since it serializes access to the underlying
// buffered writer. // buffered writer.
func (client *Client) sendMessage(msg *Message) error { func (client *Client) sendMessage(msg interface{}) error {
// Write message kind buf := new(bytes.Buffer)
err := binary.Write(client.writer, binary.BigEndian, msg.kind) var (
kind uint16
msgData []byte
err error
)
kind = mumbleproto.MessageType(msg)
if kind == mumbleproto.MessageUDPTunnel {
msgData = msg.([]byte)
} else {
msgData, err = proto.Marshal(msg)
if err != nil {
return err
}
}
err = binary.Write(buf, binary.BigEndian, kind)
if err != nil {
return err
}
err = binary.Write(buf, binary.BigEndian, uint32(len(msgData)))
if err != nil {
return err
}
_, err = buf.Write(msgData)
if err != nil { if err != nil {
return err return err
} }
// Message length _, err = client.conn.Write(buf.Bytes())
err = binary.Write(client.writer, binary.BigEndian, uint32(len(msg.buf)))
if err != nil {
return err
}
// Message buffer itself
_, err = client.writer.Write(msg.buf)
if err != nil {
return err
}
// Flush it, no need to keep it in the buffer for any longer.
err = client.writer.Flush()
if err != nil { if err != nil {
return err return err
} }
@ -397,31 +376,6 @@ func (client *Client) sendMessage(msg *Message) error {
return nil return nil
} }
// Sender Goroutine. The sender goroutine will initiate shutdown
// if it receives a nil Message.
//
// On shutdown, it will send a true boolean value on the client's
// doneSending channel. This allows the client to send all the messages
// that remain in it's buffer when the server has to force a disconnect.
func (client *Client) sender() {
defer func() {
client.doneSending <- true
}()
for msg := range client.msgchan {
if msg == nil {
return
}
err := client.sendMessage(msg)
if err != nil {
// fixme(mkrautz): This is a deadlock waiting to happen.
client.Panicf("Unable to send message to client")
return
}
}
}
// Receiver Goroutine // Receiver Goroutine
func (client *Client) receiver() { func (client *Client) receiver() {
for { for {
@ -480,7 +434,7 @@ func (client *Client) receiver() {
// information we must send it our version information so it knows // information we must send it our version information so it knows
// what version of the protocol it should speak. // what version of the protocol it should speak.
if client.state == StateClientConnected { if client.state == StateClientConnected {
client.sendProtoMessage(&mumbleproto.Version{ client.sendMessage(&mumbleproto.Version{
Version: proto.Uint32(0x10203), Version: proto.Uint32(0x10203),
Release: proto.String("Grumble"), Release: proto.String("Grumble"),
}) })
@ -570,7 +524,7 @@ func (client *Client) sendChannelTree(channel *Channel) {
} }
chanstate.Links = links chanstate.Links = links
err := client.sendProtoMessage(chanstate) err := client.sendMessage(chanstate)
if err != nil { if err != nil {
client.Panicf("%v", err) client.Panicf("%v", err)
} }
@ -588,7 +542,7 @@ func (client *Client) cryptResync() {
if requestElapsed > 5 { if requestElapsed > 5 {
client.lastResync = time.Seconds() client.lastResync = time.Seconds()
cryptsetup := &mumbleproto.CryptSetup{} cryptsetup := &mumbleproto.CryptSetup{}
err := client.sendProtoMessage(cryptsetup) err := client.sendMessage(cryptsetup)
if err != nil { if err != nil {
client.Panicf("%v", err) client.Panicf("%v", err)
} }

View file

@ -57,7 +57,7 @@ func (server *Server) handleCryptSetup(client *Client, msg *Message) {
if copy(cs.ClientNonce, client.crypt.EncryptIV[0:]) != aes.BlockSize { if copy(cs.ClientNonce, client.crypt.EncryptIV[0:]) != aes.BlockSize {
return return
} }
client.sendProtoMessage(cs) client.sendMessage(cs)
} else { } else {
client.Printf("Received client nonce") client.Printf("Received client nonce")
if len(cs.ClientNonce) != aes.BlockSize { if len(cs.ClientNonce) != aes.BlockSize {
@ -113,7 +113,7 @@ func (server *Server) handlePingMessage(client *Client, msg *Message) {
client.TcpPackets = *ping.TcpPackets client.TcpPackets = *ping.TcpPackets
} }
client.sendProtoMessage(&mumbleproto.Ping{ client.sendMessage(&mumbleproto.Ping{
Timestamp: ping.Timestamp, Timestamp: ping.Timestamp,
Good: proto.Uint32(uint32(client.crypt.Good)), Good: proto.Uint32(uint32(client.crypt.Good)),
Late: proto.Uint32(uint32(client.crypt.Late)), Late: proto.Uint32(uint32(client.crypt.Late)),
@ -917,7 +917,7 @@ func (server *Server) handleBanListMessage(client *Client, msg *Message) {
entry.Duration = proto.Uint32(ban.Duration) entry.Duration = proto.Uint32(ban.Duration)
banlist.Bans = append(banlist.Bans, entry) banlist.Bans = append(banlist.Bans, entry)
} }
if err := client.sendProtoMessage(banlist); err != nil { if err := client.sendMessage(banlist); err != nil {
client.Panic("Unable to send BanList") client.Panic("Unable to send BanList")
} }
} else { } else {
@ -1017,7 +1017,7 @@ func (server *Server) handleTextMessage(client *Client, msg *Message) {
delete(clients, client.Session) delete(clients, client.Session)
for _, target := range clients { for _, target := range clients {
target.sendProtoMessage(&mumbleproto.TextMessage{ target.sendMessage(&mumbleproto.TextMessage{
Actor: proto.Uint32(client.Session), Actor: proto.Uint32(client.Session),
Message: txtmsg.Message, Message: txtmsg.Message,
}) })
@ -1147,7 +1147,7 @@ func (server *Server) handleAclMessage(client *Client, msg *Message) {
reply.Groups = append(reply.Groups, mpgroup) reply.Groups = append(reply.Groups, mpgroup)
} }
if err := client.sendProtoMessage(reply); err != nil { if err := client.sendMessage(reply); err != nil {
client.Panic(err) client.Panic(err)
return return
} }
@ -1164,7 +1164,7 @@ func (server *Server) handleAclMessage(client *Client, msg *Message) {
queryusers.Names = append(queryusers.Names, user.Name) queryusers.Names = append(queryusers.Names, user.Name)
} }
if len(queryusers.Ids) > 0 { if len(queryusers.Ids) > 0 {
client.sendProtoMessage(queryusers) client.sendMessage(queryusers)
} }
// Set new groups and ACLs // Set new groups and ACLs
@ -1272,7 +1272,7 @@ func (server *Server) handleQueryUsers(client *Client, msg *Message) {
} }
} }
if err := client.sendProtoMessage(reply); err != nil { if err := client.sendMessage(reply); err != nil {
client.Panic(err) client.Panic(err)
return return
} }
@ -1377,7 +1377,7 @@ func (server *Server) handleUserStatsMessage(client *Client, msg *Message) {
// fixme(mkrautz): we don't do bandwidth tracking yet // fixme(mkrautz): we don't do bandwidth tracking yet
if err := client.sendProtoMessage(stats); err != nil { if err := client.sendMessage(stats); err != nil {
client.Panic(err) client.Panic(err)
return return
} }
@ -1427,7 +1427,7 @@ func (server *Server) handleRequestBlob(client *Client, msg *Message) {
userstate.Reset() userstate.Reset()
userstate.Session = proto.Uint32(uint32(target.Session)) userstate.Session = proto.Uint32(uint32(target.Session))
userstate.Texture = buf userstate.Texture = buf
if err := client.sendProtoMessage(userstate); err != nil { if err := client.sendMessage(userstate); err != nil {
client.Panic(err) client.Panic(err)
return return
} }
@ -1452,7 +1452,7 @@ func (server *Server) handleRequestBlob(client *Client, msg *Message) {
userstate.Reset() userstate.Reset()
userstate.Session = proto.Uint32(uint32(target.Session)) userstate.Session = proto.Uint32(uint32(target.Session))
userstate.Comment = proto.String(string(buf)) userstate.Comment = proto.String(string(buf))
if err := client.sendProtoMessage(userstate); err != nil { if err := client.sendMessage(userstate); err != nil {
client.Panic(err) client.Panic(err)
return return
} }
@ -1476,7 +1476,7 @@ func (server *Server) handleRequestBlob(client *Client, msg *Message) {
} }
chanstate.ChannelId = proto.Uint32(uint32(channel.Id)) chanstate.ChannelId = proto.Uint32(uint32(channel.Id))
chanstate.Description = proto.String(string(buf)) chanstate.Description = proto.String(string(buf))
if err := client.sendProtoMessage(chanstate); err != nil { if err := client.sendMessage(chanstate); err != nil {
client.Panic(err) client.Panic(err)
return return
} }
@ -1512,7 +1512,7 @@ func (server *Server) handleUserList(client *Client, msg *Message) {
Name: proto.String(user.Name), Name: proto.String(user.Name),
}) })
} }
if err := client.sendProtoMessage(userlist); err != nil { if err := client.sendMessage(userlist); err != nil {
client.Panic(err) client.Panic(err)
return return
} }

View file

@ -45,6 +45,7 @@ func MessageType(msg interface{}) uint16 {
case *Version: case *Version:
return MessageVersion return MessageVersion
case *UDPTunnel: case *UDPTunnel:
case []byte:
return MessageUDPTunnel return MessageUDPTunnel
case *Authenticate: case *Authenticate:
return MessageAuthenticate return MessageAuthenticate

View file

@ -260,10 +260,9 @@ func (server *Server) NewClient(conn net.Conn) (err error) {
client.server = server client.server = server
client.conn = conn client.conn = conn
client.reader = bufio.NewReader(client.conn) client.reader = bufio.NewReader(client.conn)
client.writer = bufio.NewWriter(client.conn)
client.state = StateClientConnected client.state = StateClientConnected
client.msgchan = make(chan *Message)
client.udprecv = make(chan []byte) client.udprecv = make(chan []byte)
client.user = nil client.user = nil
@ -271,9 +270,6 @@ func (server *Server) NewClient(conn net.Conn) (err error) {
go client.receiver() go client.receiver()
go client.udpreceiver() go client.udpreceiver()
client.doneSending = make(chan bool)
go client.sender()
return return
} }
@ -505,7 +501,7 @@ func (server *Server) handleAuthenticate(client *Client, msg *Message) {
// Send CryptState information to the client so it can establish an UDP connection, // Send CryptState information to the client so it can establish an UDP connection,
// if it wishes. // if it wishes.
client.lastResync = time.Seconds() client.lastResync = time.Seconds()
err = client.sendProtoMessage(&mumbleproto.CryptSetup{ err = client.sendMessage(&mumbleproto.CryptSetup{
Key: client.crypt.RawKey[0:], Key: client.crypt.RawKey[0:],
ClientNonce: client.crypt.DecryptIV[0:], ClientNonce: client.crypt.DecryptIV[0:],
ServerNonce: client.crypt.EncryptIV[0:], ServerNonce: client.crypt.EncryptIV[0:],
@ -629,12 +625,12 @@ func (server *Server) finishAuthenticate(client *Client) {
perm.ClearCacheBit() perm.ClearCacheBit()
sync.Permissions = proto.Uint64(uint64(perm)) sync.Permissions = proto.Uint64(uint64(perm))
} }
if err := client.sendProtoMessage(sync); err != nil { if err := client.sendMessage(sync); err != nil {
client.Panicf("%v", err) client.Panicf("%v", err)
return return
} }
err := client.sendProtoMessage(&mumbleproto.ServerConfig{ err := client.sendMessage(&mumbleproto.ServerConfig{
AllowHtml: proto.Bool(server.cfg.BoolValue("AllowHTML")), AllowHtml: proto.Bool(server.cfg.BoolValue("AllowHTML")),
MessageLength: proto.Uint32(server.cfg.Uint32Value("MaxTextMessageLength")), MessageLength: proto.Uint32(server.cfg.Uint32Value("MaxTextMessageLength")),
ImageMessageLength: proto.Uint32(server.cfg.Uint32Value("MaxImageMessageLength")), ImageMessageLength: proto.Uint32(server.cfg.Uint32Value("MaxImageMessageLength")),
@ -780,7 +776,7 @@ func (server *Server) sendUserList(client *Client) {
userstate.PluginIdentity = proto.String(connectedClient.PluginIdentity) userstate.PluginIdentity = proto.String(connectedClient.PluginIdentity)
} }
err := client.sendProtoMessage(userstate) err := client.sendMessage(userstate)
if err != nil { if err != nil {
// Server panic? // Server panic?
continue continue
@ -800,7 +796,7 @@ func (server *Server) sendClientPermissions(client *Client, channel *Channel) {
perm := server.aclcache.GetPermission(client, channel) perm := server.aclcache.GetPermission(client, channel)
// fixme(mkrautz): Cache which permissions we've already sent. // fixme(mkrautz): Cache which permissions we've already sent.
client.sendProtoMessage(&mumbleproto.PermissionQuery{ client.sendMessage(&mumbleproto.PermissionQuery{
ChannelId: proto.Uint32(uint32(channel.Id)), ChannelId: proto.Uint32(uint32(channel.Id)),
Permissions: proto.Uint32(uint32(perm)), Permissions: proto.Uint32(uint32(perm)),
}) })
@ -816,7 +812,7 @@ func (server *Server) broadcastProtoMessageWithPredicate(msg interface{}, client
if client.state < StateClientAuthenticated { if client.state < StateClientAuthenticated {
continue continue
} }
err := client.sendProtoMessage(msg) err := client.sendMessage(msg)
if err != nil { if err != nil {
return err return err
} }