1
0
Fork 0
forked from External/grumble

Merge pull request #18: WebSocket support

This commit is contained in:
Mikkel Krautz 2018-02-09 22:21:52 +01:00 committed by GitHub
commit cb1d3db727
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 257 additions and 39 deletions

View file

@ -1328,7 +1328,9 @@ func (server *Server) handleUserStatsMessage(client *Client, msg *Message) {
stats.Session = proto.Uint32(target.Session()) stats.Session = proto.Uint32(target.Session())
if details { if details {
if tlsconn := target.conn.(*tls.Conn); tlsconn != nil { // Only consider client certificates for direct connections, not WebSocket connections.
// We do not support TLS-level client certificates for WebSocket client.
if tlsconn, ok := target.conn.(*tls.Conn); ok {
state := tlsconn.ConnectionState() state := tlsconn.ConnectionState()
for i := len(state.PeerCertificates) - 1; i >= 0; i-- { for i := len(state.PeerCertificates) - 1; i >= 0; i-- {
stats.Certificates = append(stats.Certificates, state.PeerCertificates[i].Raw) stats.Certificates = append(stats.Certificates, state.PeerCertificates[i].Raw)

View file

@ -7,6 +7,7 @@ package main
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/tls" "crypto/tls"
@ -25,7 +26,9 @@ import (
"mumble.info/grumble/pkg/mumbleproto" "mumble.info/grumble/pkg/mumbleproto"
"mumble.info/grumble/pkg/serverconf" "mumble.info/grumble/pkg/serverconf"
"mumble.info/grumble/pkg/sessionpool" "mumble.info/grumble/pkg/sessionpool"
"mumble.info/grumble/pkg/web"
"net" "net"
"net/http"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
@ -34,6 +37,7 @@ import (
// The default port a Murmur server listens on // The default port a Murmur server listens on
const DefaultPort = 64738 const DefaultPort = 64738
const DefaultWebPort = 443
const UDPPacketSize = 1024 const UDPPacketSize = 1024
const LogOpsBeforeSync = 100 const LogOpsBeforeSync = 100
@ -57,13 +61,16 @@ type KeyValuePair struct {
type Server struct { type Server struct {
Id int64 Id int64
tcpl *net.TCPListener tcpl *net.TCPListener
tlsl net.Listener tlsl net.Listener
udpconn *net.UDPConn udpconn *net.UDPConn
tlscfg *tls.Config tlscfg *tls.Config
bye chan bool webwsl *web.Listener
netwg sync.WaitGroup webtlscfg *tls.Config
running bool webhttp *http.Server
bye chan bool
netwg sync.WaitGroup
running bool
incoming chan *Message incoming chan *Message
voicebroadcast chan *VoiceBroadcast voicebroadcast chan *VoiceBroadcast
@ -256,27 +263,30 @@ func (server *Server) handleIncomingClient(conn net.Conn) (err error) {
client.user = nil client.user = nil
// Extract user's cert hash // Extract user's cert hash
tlsconn := client.conn.(*tls.Conn) // Only consider client certificates for direct connections, not WebSocket connections.
err = tlsconn.Handshake() // We do not support TLS-level client certificates for WebSocket client.
if err != nil { if tlsconn, ok := client.conn.(*tls.Conn); ok {
client.Printf("TLS handshake failed: %v", err) err = tlsconn.Handshake()
client.Disconnect() if err != nil {
return client.Printf("TLS handshake failed: %v", err)
} client.Disconnect()
return
}
state := tlsconn.ConnectionState() state := tlsconn.ConnectionState()
if len(state.PeerCertificates) > 0 { if len(state.PeerCertificates) > 0 {
hash := sha1.New() hash := sha1.New()
hash.Write(state.PeerCertificates[0].Raw) hash.Write(state.PeerCertificates[0].Raw)
sum := hash.Sum(nil) sum := hash.Sum(nil)
client.certHash = hex.EncodeToString(sum) client.certHash = hex.EncodeToString(sum)
} }
// Check whether the client's cert hash is banned // Check whether the client's cert hash is banned
if server.IsCertHashBanned(client.CertHash()) { if server.IsCertHashBanned(client.CertHash()) {
client.Printf("Certificate hash is banned") client.Printf("Certificate hash is banned")
client.Disconnect() client.Disconnect()
return return
}
} }
// Launch network readers // Launch network readers
@ -1090,7 +1100,7 @@ func (s *Server) RegisterClient(client *Client) (uid uint32, err error) {
} }
// Grumble can only register users with certificates. // Grumble can only register users with certificates.
if client.HasCertificate() { if !client.HasCertificate() {
return 0, errors.New("no cert hash") return 0, errors.New("no cert hash")
} }
@ -1258,12 +1268,12 @@ func (server *Server) FilterText(text string) (filtered string, err error) {
} }
// The accept loop of the server. // The accept loop of the server.
func (server *Server) acceptLoop() { func (server *Server) acceptLoop(listener net.Listener) {
defer server.netwg.Done() defer server.netwg.Done()
for { for {
// New client connected // New client connected
conn, err := server.tlsl.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
if isTimeout(err) { if isTimeout(err) {
continue continue
@ -1334,8 +1344,8 @@ func (server *Server) cleanPerLaunchData() {
server.clientAuthenticated = nil server.clientAuthenticated = nil
} }
// Returns the port the server will listen on when it is // Returns the port the native server will listen on when it is
// started. Returns 0 on failure. // started.
func (server *Server) Port() int { func (server *Server) Port() int {
port := server.cfg.IntValue("Port") port := server.cfg.IntValue("Port")
if port == 0 { if port == 0 {
@ -1344,7 +1354,17 @@ func (server *Server) Port() int {
return port return port
} }
// Returns the port the server is currently listning // Returns the port the web server will listen on when it is
// started.
func (server *Server) WebPort() int {
port := server.cfg.IntValue("WebPort")
if port == 0 {
return DefaultWebPort + int(server.Id) - 1
}
return port
}
// Returns the port the native server is currently listening
// on. If called when the server is not running, // on. If called when the server is not running,
// this function returns -1. // this function returns -1.
func (server *Server) CurrentPort() int { func (server *Server) CurrentPort() int {
@ -1374,6 +1394,7 @@ func (server *Server) Start() (err error) {
host := server.HostAddress() host := server.HostAddress()
port := server.Port() port := server.Port()
webport := server.WebPort()
// Setup our UDP listener // Setup our UDP listener
server.udpconn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(host), Port: port}) server.udpconn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(host), Port: port})
@ -1412,7 +1433,37 @@ func (server *Server) Start() (err error) {
} }
server.tlsl = tls.NewListener(server.tcpl, server.tlscfg) server.tlsl = tls.NewListener(server.tcpl, server.tlscfg)
server.Printf("Started: listening on %v", server.tcpl.Addr()) // Create HTTP server and WebSocket "listener"
webaddr := &net.TCPAddr{IP: net.ParseIP(host), Port: webport}
server.webtlscfg = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{"http/1.1"},
}
server.webwsl = web.NewListener(webaddr, server.Logger)
mux := http.NewServeMux()
mux.Handle("/", server.webwsl)
server.webhttp = &http.Server{
Addr: webaddr.String(),
Handler: mux,
TLSConfig: server.webtlscfg,
ErrorLog: server.Logger,
// Set sensible timeouts, in case no reverse proxy is in front of Grumble.
// Non-conforming (or malicious) clients may otherwise block indefinitely and cause
// file descriptors (or handles, depending on your OS) to leak and/or be exhausted
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 2 * time.Minute,
}
go func() {
err := server.webhttp.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
server.Fatalf("Fatal HTTP server error: %v", err)
}
}()
server.Printf("Started: listening on %v and %v", server.tcpl.Addr(), server.webwsl.Addr())
server.running = true server.running = true
// Open a fresh freezer log // Open a fresh freezer log
@ -1428,16 +1479,17 @@ func (server *Server) Start() (err error) {
// Launch the event handler goroutine // Launch the event handler goroutine
go server.handlerLoop() go server.handlerLoop()
// Add the two network receiver goroutines to the net waitgroup // Add the three network receiver goroutines to the net waitgroup
// and launch them. // and launch them.
// //
// We use the waitgroup to provide a blocking Stop() method // We use the waitgroup to provide a blocking Stop() method
// for the servers. Each network goroutine defers a call to // for the servers. Each network goroutine defers a call to
// netwg.Done(). In the Stop() we close all the connections // netwg.Done(). In the Stop() we close all the connections
// and call netwg.Wait() to wait for the goroutines to end. // and call netwg.Wait() to wait for the goroutines to end.
server.netwg.Add(2) server.netwg.Add(3)
go server.udpListenLoop() go server.udpListenLoop()
go server.acceptLoop() go server.acceptLoop(server.tlsl)
go server.acceptLoop(server.webwsl)
// Schedule a server registration update (if needed) // Schedule a server registration update (if needed)
go func() { go func() {
@ -1461,7 +1513,21 @@ func (server *Server) Stop() (err error) {
client.Disconnect() client.Disconnect()
} }
// Close the TLS listener and the TCP listener // Wait for the HTTP server to shutdown gracefully
// A client could theoretically block the server from ever stopping by
// never letting the HTTP connection go idle, so we give 15 seconds of grace time.
// This does not apply to opened WebSockets, which were forcibly closed when
// all clients were disconnected.
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(15*time.Second))
err = server.webhttp.Shutdown(ctx)
cancel()
if err == context.DeadlineExceeded {
server.Println("Forcibly shutdown HTTP server while stopping")
} else if err != nil {
return err
}
// Close the listeners
err = server.tlsl.Close() err = server.tlsl.Close()
if err != nil { if err != nil {
return err return err
@ -1470,6 +1536,10 @@ func (server *Server) Stop() (err error) {
if err != nil { if err != nil {
return err return err
} }
err = server.webwsl.Close()
if err != nil {
return err
}
// Close the UDP connection // Close the UDP connection
err = server.udpconn.Close() err = server.udpconn.Close()
@ -1485,7 +1555,7 @@ func (server *Server) Stop() (err error) {
server.Fatal(err) server.Fatal(err)
} }
// Wait for the two network receiver // Wait for the three network receiver
// goroutines end. // goroutines end.
server.netwg.Wait() server.netwg.Wait()

67
pkg/web/websocket.go Normal file
View file

@ -0,0 +1,67 @@
// Copyright (c) 2018 The Grumble Authors
// The use of this source code is governed by a BSD-style
// license that can be found in the LICENSE-file.
package web
import (
"bytes"
"io"
"net"
"time"
"github.com/gorilla/websocket"
)
type conn struct {
ws *websocket.Conn
msgbuf bytes.Buffer
}
func (c *conn) Read(b []byte) (n int, err error) {
if c.msgbuf.Len() == 0 {
_, r, err := c.ws.NextReader()
if err != nil {
if _, ok := err.(*websocket.CloseError); ok {
return 0, io.EOF
}
return 0, err
}
if _, err := c.msgbuf.ReadFrom(r); err != nil {
return 0, err
}
}
// Impossible to read over message boundaries - will generate EOF
return c.msgbuf.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
return len(b), c.ws.WriteMessage(websocket.BinaryMessage, b)
}
func (c *conn) Close() error {
return c.ws.Close()
}
func (c *conn) LocalAddr() net.Addr {
return c.ws.LocalAddr()
}
func (c *conn) RemoteAddr() net.Addr {
return c.ws.RemoteAddr()
}
func (c *conn) SetDeadline(t time.Time) (err error) {
if err = c.ws.SetReadDeadline(t); err != nil {
return err
}
return c.ws.SetWriteDeadline(t)
}
func (c *conn) SetReadDeadline(t time.Time) error {
return c.ws.SetReadDeadline(t)
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return c.ws.SetWriteDeadline(t)
}

79
pkg/web/wslisten.go Normal file
View file

@ -0,0 +1,79 @@
// Copyright (c) 2018 The Grumble Authors
// The use of this source code is governed by a BSD-style
// license that can be found in the LICENSE-file.
package web
import (
"fmt"
"log"
"net"
"net/http"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
HandshakeTimeout: 20 * time.Second,
Subprotocols: []string{"mumble", "binary"},
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Listener struct {
sockets chan *conn
done chan struct{}
addr net.Addr
closed int32
logger *log.Logger
}
func NewListener(laddr net.Addr, logger *log.Logger) *Listener {
return &Listener{
sockets: make(chan *conn),
done: make(chan struct{}),
addr: laddr,
logger: logger,
}
}
func (l *Listener) Accept() (net.Conn, error) {
if atomic.LoadInt32(&l.closed) != 0 {
return nil, fmt.Errorf("accept ws %v: use of closed websocket listener", l.addr)
}
select {
case ws := <-l.sockets:
return ws, nil
case <-l.done:
return nil, fmt.Errorf("accept ws %v: use of closed websocket listener", l.addr)
}
}
func (l *Listener) Close() error {
if !atomic.CompareAndSwapInt32(&l.closed, 0, 1) {
return fmt.Errorf("close ws %v: use of closed websocket listener", l.addr)
}
close(l.done)
return nil
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if atomic.LoadInt32(&l.closed) != 0 {
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
return
}
l.logger.Printf("Upgrading web connection from: %v", r.RemoteAddr)
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
l.logger.Printf("Failed upgrade: %v", err)
return
}
l.sockets <- &conn{ws: ws}
}