diff --git a/irc/ircconn.go b/irc/ircconn.go index 088909a2..f0fe947a 100644 --- a/irc/ircconn.go +++ b/irc/ircconn.go @@ -5,6 +5,7 @@ package irc import ( "bytes" + "io" "net" "unicode/utf8" @@ -93,21 +94,25 @@ func (cc *IRCStreamConn) Close() (err error) { // IRCWSConn is an IRCConn over a websocket. type IRCWSConn struct { conn *websocket.Conn + buf []byte binary bool } -func NewIRCWSConn(conn *websocket.Conn) IRCWSConn { - binary := conn.Subprotocol() == "binary.ircv3.net" - return IRCWSConn{conn: conn, binary: binary} +func NewIRCWSConn(conn *websocket.Conn) *IRCWSConn { + return &IRCWSConn{ + conn: conn, + binary: conn.Subprotocol() == "binary.ircv3.net", + buf: make([]byte, maxReadQBytes()), + } } -func (wc IRCWSConn) UnderlyingConn() *utils.WrappedConn { +func (wc *IRCWSConn) UnderlyingConn() *utils.WrappedConn { // just assume that the type is OK wConn, _ := wc.conn.UnderlyingConn().(*utils.WrappedConn) return wConn } -func (wc IRCWSConn) WriteLine(buf []byte) (err error) { +func (wc *IRCWSConn) WriteLine(buf []byte) (err error) { buf = bytes.TrimSuffix(buf, crlf) // #1483: if we have websockets at all, then we're enforcing utf8 messageType := websocket.TextMessage @@ -117,7 +122,7 @@ func (wc IRCWSConn) WriteLine(buf []byte) (err error) { return wc.conn.WriteMessage(messageType, buf) } -func (wc IRCWSConn) WriteLines(buffers [][]byte) (err error) { +func (wc *IRCWSConn) WriteLines(buffers [][]byte) (err error) { for _, buf := range buffers { err = wc.WriteLine(buf) if err != nil { @@ -127,20 +132,35 @@ func (wc IRCWSConn) WriteLines(buffers [][]byte) (err error) { return } -func (wc IRCWSConn) ReadLine() (line []byte, err error) { - messageType, line, err := wc.conn.ReadMessage() - if err == nil { +func (wc *IRCWSConn) ReadLine() (line []byte, err error) { + messageType, reader, err := wc.conn.NextReader() + switch err { + case nil: + // OK + case websocket.ErrReadLimit: + return line, ircreader.ErrReadQ + default: + return line, err + } + + n, err := io.ReadFull(reader, wc.buf) + line = wc.buf[:n] + switch err { + case io.ErrUnexpectedEOF, io.EOF: + // these are OK. io.ErrUnexpectedEOF is the good case: + // it means we read the full message and it consumed less than the full wc.buf if messageType == websocket.BinaryMessage && !utf8.Valid(line) { return line, errInvalidUtf8 } return line, nil - } else if err == websocket.ErrReadLimit { + case nil, websocket.ErrReadLimit: + // nil means we filled wc.buf without exhausting the reader: return line, ircreader.ErrReadQ - } else { + default: return line, err } } -func (wc IRCWSConn) Close() (err error) { +func (wc *IRCWSConn) Close() (err error) { return wc.conn.Close() }