scripting API for IP bans

See discussion on #68.
This commit is contained in:
Shivaram Lingamneni 2020-09-14 04:28:12 -04:00
parent a742ef9639
commit 1a98a37a75
10 changed files with 271 additions and 105 deletions

View file

@ -4,13 +4,11 @@
package irc
import (
"bufio"
"encoding/json"
"fmt"
"io"
"os/exec"
"syscall"
"time"
"net"
"github.com/oragono/oragono/irc/utils"
)
// JSON-serializable input and output types for the script
@ -27,84 +25,77 @@ type AuthScriptOutput struct {
Error string `json:"error"`
}
// internal tupling of output and error for passing over a channel
type authScriptResponse struct {
output AuthScriptOutput
err error
}
func CheckAuthScript(sem utils.Semaphore, config ScriptConfig, input AuthScriptInput) (output AuthScriptOutput, err error) {
if sem != nil {
sem.Acquire()
defer sem.Release()
}
func CheckAuthScript(config AuthScriptConfig, input AuthScriptInput) (output AuthScriptOutput, err error) {
inputBytes, err := json.Marshal(input)
if err != nil {
return
}
cmd := exec.Command(config.Command, config.Args...)
stdin, err := cmd.StdinPipe()
outBytes, err := RunScript(config.Command, config.Args, inputBytes, config.Timeout, config.KillTimeout)
if err != nil {
return
}
stdout, err := cmd.StdoutPipe()
err = json.Unmarshal(outBytes, &output)
if err != nil {
return
}
channel := make(chan authScriptResponse, 1)
err = cmd.Start()
if err != nil {
return
if output.Error != "" {
err = fmt.Errorf("Authentication process reported error: %s", output.Error)
}
stdin.Write(inputBytes)
stdin.Write([]byte{'\n'})
// lots of potential race conditions here. we want to ensure that Wait()
// will be called, and will return, on the other goroutine, no matter
// where it is blocked. If it's blocked on ReadBytes(), we will kill it
// (first with SIGTERM, then with SIGKILL) and ReadBytes will return
// with EOF. If it's blocked on Wait(), then one of the kill signals
// will succeed and unblock it.
go processAuthScriptOutput(cmd, stdout, channel)
outputTimer := time.NewTimer(config.Timeout)
select {
case response := <-channel:
return response.output, response.err
case <-outputTimer.C:
}
err = errTimedOut
cmd.Process.Signal(syscall.SIGTERM)
termTimer := time.NewTimer(config.Timeout)
select {
case <-channel:
return
case <-termTimer.C:
}
cmd.Process.Kill()
return
}
func processAuthScriptOutput(cmd *exec.Cmd, stdout io.Reader, channel chan authScriptResponse) {
var response authScriptResponse
var out AuthScriptOutput
type IPScriptResult uint
reader := bufio.NewReader(stdout)
outBytes, err := reader.ReadBytes('\n')
if err == nil {
err = json.Unmarshal(outBytes, &out)
if err == nil {
response.output = out
if out.Error != "" {
err = fmt.Errorf("Authentication process reported error: %s", out.Error)
}
}
}
response.err = err
const (
IPNotChecked IPScriptResult = 0
IPAccepted IPScriptResult = 1
IPBanned IPScriptResult = 2
IPRequireSASL IPScriptResult = 3
)
// always call Wait() to ensure resource cleanup
err = cmd.Wait()
if err != nil {
response.err = err
}
channel <- response
type IPScriptInput struct {
IP string `json:"ip"`
}
type IPScriptOutput struct {
Result IPScriptResult `json:"result"`
BanMessage string `json:"banMessage"`
// for caching: the network to which this result is applicable, and a TTL in seconds:
CacheNet string `json:"cacheNet"`
CacheSeconds int `json:"cacheSeconds"`
Error string `json:"error"`
}
func CheckIPBan(sem utils.Semaphore, config ScriptConfig, addr net.IP) (output IPScriptOutput, err error) {
if sem != nil {
sem.Acquire()
defer sem.Release()
}
inputBytes, err := json.Marshal(IPScriptInput{IP: addr.String()})
if err != nil {
return
}
outBytes, err := RunScript(config.Command, config.Args, inputBytes, config.Timeout, config.KillTimeout)
if err != nil {
return
}
err = json.Unmarshal(outBytes, &output)
if err != nil {
return
}
if output.Error != "" {
err = fmt.Errorf("IP ban process reported error: %s", output.Error)
} else if !(IPAccepted <= output.Result && output.Result <= IPRequireSASL) {
err = fmt.Errorf("Invalid result from IP checking script: %d", output.Result)
}
return
}