1
0
Fork 0
forked from External/ergo

support unix domain sockets

This commit is contained in:
Shivaram Lingamneni 2018-02-01 15:53:49 -05:00
parent bec39ee8cb
commit 2a7f055ef3
5 changed files with 91 additions and 40 deletions

View file

@ -267,19 +267,15 @@ func (server *Server) Run() {
func (server *Server) acceptClient(conn clientConn) {
// check IP address
ipaddr := net.ParseIP(utils.IPString(conn.Conn.RemoteAddr()))
if ipaddr == nil {
conn.Conn.Write([]byte(couldNotParseIPMsg))
conn.Conn.Close()
return
}
isBanned, banMsg := server.checkBans(ipaddr)
if isBanned {
// this might not show up properly on some clients, but our objective here is just to close the connection out before it has a load impact on us
conn.Conn.Write([]byte(fmt.Sprintf(errorMsg, banMsg)))
conn.Conn.Close()
return
ipaddr := utils.AddrToIP(conn.Conn.RemoteAddr())
if ipaddr != nil {
isBanned, banMsg := server.checkBans(ipaddr)
if isBanned {
// this might not show up properly on some clients, but our objective here is just to close the connection out before it has a load impact on us
conn.Conn.Write([]byte(fmt.Sprintf(errorMsg, banMsg)))
conn.Conn.Close()
return
}
}
server.logger.Debug("localconnect-ip", fmt.Sprintf("Client connecting from %v", ipaddr))
@ -336,7 +332,23 @@ func (server *Server) checkBans(ipaddr net.IP) (banned bool, message string) {
// createListener starts the given listeners.
func (server *Server) createListener(addr string, tlsConfig *tls.Config) *ListenerWrapper {
// make listener
listener, err := net.Listen("tcp", addr)
var listener net.Listener
var err error
optional_unix_prefix := "unix:"
optional_prefix_len := len(optional_unix_prefix)
if len(addr) >= optional_prefix_len && strings.ToLower(addr[0:optional_prefix_len]) == optional_unix_prefix {
addr = addr[optional_prefix_len:]
if len(addr) == 0 || addr[0] != '/' {
log.Fatal("Bad unix socket address", addr)
}
}
if len(addr) > 0 && addr[0] == '/' {
// https://stackoverflow.com/a/34881585
os.Remove(addr)
listener, err = net.Listen("unix", addr)
} else {
listener, err = net.Listen("tcp", addr)
}
if err != nil {
log.Fatal(server, "listen error: ", err)
}