diff --git a/irc/client.go b/irc/client.go index 7058c8e1..ab096af5 100644 --- a/irc/client.go +++ b/irc/client.go @@ -27,7 +27,6 @@ type Client struct { phase Phase quitTimer *time.Timer realname string - replies chan Reply server *Server socket *Socket username string @@ -43,14 +42,12 @@ func NewClient(server *Server, conn net.Conn) *Client { friends: make(map[*Client]uint), hostname: AddrLookupHostname(conn.RemoteAddr()), phase: server.InitPhase(), - replies: make(chan Reply), server: server, socket: NewSocket(conn), } client.loginTimer = time.AfterFunc(LOGIN_TIMEOUT, client.connectionClosed) go client.readCommands() - go client.writeReplies() return client } @@ -72,16 +69,6 @@ func (client *Client) readCommands() { } } -func (client *Client) writeReplies() { - for reply := range client.replies { - if DEBUG_CLIENT { - log.Printf("%s ← %s", client, reply) - } - - client.socket.Write(reply.Format(client)) - } -} - func (client *Client) Touch() { client.atime = time.Now() @@ -133,18 +120,26 @@ func (client *Client) Destroy() { log.Printf("%s destroy", client) } + // clean up self + client.socket.Close() client.loginTimer.Stop() - if client.idleTimer != nil { client.idleTimer.Stop() } - if client.quitTimer != nil { client.quitTimer.Stop() } + // clean up channels + + for channel := range client.channels { + channel.Quit(client) + } + + // clean up server + client.server.clients.Remove(client) if DEBUG_CLIENT { @@ -153,13 +148,7 @@ func (client *Client) Destroy() { } func (client *Client) Reply(reply Reply) { - if client.replies == nil { - if DEBUG_CLIENT { - log.Printf("%s dropped %s", client, reply) - } - return - } - client.replies <- reply + client.socket.Write(reply.Format(client)...) } func (client *Client) IdleTime() time.Duration { @@ -250,10 +239,6 @@ func (client *Client) Quit(message string) { } } - for channel := range client.channels { - channel.Quit(client) - } - client.Reply(RplError(client.server, client)) client.Destroy() } diff --git a/irc/server.go b/irc/server.go index 80188e9f..a876bd19 100644 --- a/irc/server.go +++ b/irc/server.go @@ -344,9 +344,7 @@ func (m *UserCommand) HandleServer(s *Server) { } func (msg *QuitCommand) HandleServer(server *Server) { - client := msg.Client() - client.Quit(msg.message) - server.clients.Remove(client) + msg.Client().Quit(msg.message) } func (m *JoinCommand) HandleServer(s *Server) { diff --git a/irc/socket.go b/irc/socket.go index de7ed770..113710a1 100644 --- a/irc/socket.go +++ b/irc/socket.go @@ -7,8 +7,15 @@ import ( "strings" ) +const ( + R = '→' + W = '←' +) + type Socket struct { + closed bool conn net.Conn + done chan bool reader *bufio.Reader receive chan string send chan string @@ -18,6 +25,7 @@ type Socket struct { func NewSocket(conn net.Conn) *Socket { socket := &Socket{ conn: conn, + done: make(chan bool), reader: bufio.NewReader(conn), receive: make(chan string), send: make(chan string), @@ -35,14 +43,20 @@ func (socket *Socket) String() string { } func (socket *Socket) Close() { - socket.conn.Close() + if socket.closed { + return + } + + socket.closed = true + socket.done <- true + close(socket.done) } func (socket *Socket) Read() <-chan string { return socket.receive } -func (socket *Socket) Write(lines []string) { +func (socket *Socket) Write(lines ...string) { for _, line := range lines { socket.send <- line } @@ -52,10 +66,7 @@ func (socket *Socket) Write(lines []string) { func (socket *Socket) readLines() { for { line, err := socket.reader.ReadString('\n') - if err != nil { - if DEBUG_NET { - log.Printf("%s → error: %s", socket, err) - } + if socket.isError(err, R) { break } @@ -69,31 +80,51 @@ func (socket *Socket) readLines() { socket.receive <- line } + close(socket.receive) -} - -func (socket *Socket) writeLines() { - for line := range socket.send { - if DEBUG_NET { - log.Printf("%s ← %s", socket, line) - } - if _, err := socket.writer.WriteString(line); socket.maybeLogWriteError(err) { - break - } - if _, err := socket.writer.WriteString(CRLF); socket.maybeLogWriteError(err) { - break - } - - if err := socket.writer.Flush(); socket.maybeLogWriteError(err) { - break - } + if DEBUG_NET { + log.Printf("%s closed", socket) } } -func (socket *Socket) maybeLogWriteError(err error) bool { +func (socket *Socket) writeLines() { + done := false + for !done { + select { + case line := <-socket.send: + if _, err := socket.writer.WriteString(line); socket.isError(err, W) { + break + } + if _, err := socket.writer.WriteString(CRLF); socket.isError(err, W) { + break + } + + if err := socket.writer.Flush(); socket.isError(err, W) { + break + } + if DEBUG_NET { + log.Printf("%s ← %s", socket, line) + } + + case done = <-socket.done: + continue + } + } + + if DEBUG_NET { + log.Printf("%s closing", socket) + } + socket.conn.Close() + + for _ = range socket.send { + // discard lines + } +} + +func (socket *Socket) isError(err error, dir rune) bool { if err != nil { if DEBUG_NET { - log.Printf("%s ← error: %s", socket, err) + log.Printf("%s %c error: %s", socket, dir, err) } return true }