forked from External/ergo
implement SASL OAUTHBEARER and draft/bearer (#2122)
* implement SASL OAUTHBEARER and draft/bearer * Upgrade JWT lib * Fix an edge case in SASL EXTERNAL * Accept longer SASL responses * review fix: allow multiple token definitions * enhance tests * use SASL utilities from irc-go * test expired tokens
This commit is contained in:
parent
8475b62da4
commit
ee7f818674
58 changed files with 2868 additions and 975 deletions
|
|
@ -4,6 +4,7 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
|
|
@ -19,10 +20,12 @@ import (
|
|||
"github.com/tidwall/buntdb"
|
||||
"github.com/xdg-go/scram"
|
||||
|
||||
"github.com/ergochat/ergo/irc/caps"
|
||||
"github.com/ergochat/ergo/irc/connection_limits"
|
||||
"github.com/ergochat/ergo/irc/email"
|
||||
"github.com/ergochat/ergo/irc/migrations"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/oauth2"
|
||||
"github.com/ergochat/ergo/irc/passwd"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
|
@ -1395,6 +1398,10 @@ func (am *AccountManager) AuthenticateByPassphrase(client *Client, accountName s
|
|||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(accountName, caps.BearerTokenPrefix) {
|
||||
return am.AuthenticateByBearerToken(client, strings.TrimPrefix(accountName, caps.BearerTokenPrefix), passphrase)
|
||||
}
|
||||
|
||||
if throttled, remainingTime := client.checkLoginThrottle(); throttled {
|
||||
return &ThrottleError{remainingTime}
|
||||
}
|
||||
|
|
@ -1427,6 +1434,71 @@ func (am *AccountManager) AuthenticateByPassphrase(client *Client, accountName s
|
|||
return err
|
||||
}
|
||||
|
||||
func (am *AccountManager) AuthenticateByBearerToken(client *Client, tokenType, token string) (err error) {
|
||||
switch tokenType {
|
||||
case "oauth2":
|
||||
return am.AuthenticateByOAuthBearer(client, oauth2.OAuthBearerOptions{Token: token})
|
||||
case "jwt":
|
||||
return am.AuthenticateByJWT(client, token)
|
||||
default:
|
||||
return errInvalidBearerTokenType
|
||||
}
|
||||
}
|
||||
|
||||
func (am *AccountManager) AuthenticateByOAuthBearer(client *Client, opts oauth2.OAuthBearerOptions) (err error) {
|
||||
config := am.server.Config()
|
||||
|
||||
// we need to check this here since we can get here via SASL PLAIN:
|
||||
if !config.Accounts.OAuth2.Enabled {
|
||||
return errFeatureDisabled
|
||||
}
|
||||
|
||||
var username string
|
||||
if config.Accounts.AuthScript.Enabled && config.Accounts.OAuth2.AuthScript {
|
||||
username, err = am.authenticateByOAuthBearerScript(client, config, opts)
|
||||
} else {
|
||||
username, err = config.Accounts.OAuth2.Introspect(context.Background(), opts.Token)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.loadWithAutocreation(username, config.Accounts.OAuth2.Autocreate)
|
||||
if err == nil {
|
||||
am.Login(client, account)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (am *AccountManager) AuthenticateByJWT(client *Client, token string) (err error) {
|
||||
config := am.server.Config()
|
||||
// enabled check is encapsulated here:
|
||||
accountName, err := config.Accounts.JWTAuth.Validate(token)
|
||||
if err != nil {
|
||||
am.server.logger.Debug("accounts", "invalid JWT token", err.Error())
|
||||
return errAccountInvalidCredentials
|
||||
}
|
||||
account, err := am.loadWithAutocreation(accountName, config.Accounts.JWTAuth.Autocreate)
|
||||
if err == nil {
|
||||
am.Login(client, account)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (am *AccountManager) authenticateByOAuthBearerScript(client *Client, config *Config, opts oauth2.OAuthBearerOptions) (username string, err error) {
|
||||
output, err := CheckAuthScript(am.server.semaphores.AuthScript, config.Accounts.AuthScript.ScriptConfig,
|
||||
AuthScriptInput{OAuthBearer: &opts, IP: client.IP().String()})
|
||||
|
||||
if err != nil {
|
||||
am.server.logger.Error("internal", "failed shell auth invocation", err.Error())
|
||||
return "", oauth2.ErrInvalidToken
|
||||
} else if output.Success {
|
||||
return output.AccountName, nil
|
||||
} else {
|
||||
return "", oauth2.ErrInvalidToken
|
||||
}
|
||||
}
|
||||
|
||||
// AllNicks returns the uncasefolded nicknames for all accounts, including additional (grouped) nicks.
|
||||
func (am *AccountManager) AllNicks() (result []string) {
|
||||
accountNamePrefix := fmt.Sprintf(keyAccountName, "")
|
||||
|
|
@ -1939,8 +2011,10 @@ func (am *AccountManager) AuthenticateByCertificate(client *Client, certfp strin
|
|||
return err
|
||||
}
|
||||
|
||||
if authzid != "" && authzid != account {
|
||||
return errAuthzidAuthcidMismatch
|
||||
if authzid != "" {
|
||||
if cfAuthzid, err := CasefoldName(authzid); err != nil || cfAuthzid != account {
|
||||
return errAuthzidAuthcidMismatch
|
||||
}
|
||||
}
|
||||
|
||||
// ok, we found an account corresponding to their certificate
|
||||
|
|
@ -2145,6 +2219,7 @@ var (
|
|||
"PLAIN": authPlainHandler,
|
||||
"EXTERNAL": authExternalHandler,
|
||||
"SCRAM-SHA-256": authScramHandler,
|
||||
"OAUTHBEARER": authOauthBearerHandler,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/ergochat/ergo/irc/oauth2"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
||||
|
|
@ -20,7 +21,8 @@ type AuthScriptInput struct {
|
|||
Certfp string `json:"certfp,omitempty"`
|
||||
PeerCerts []string `json:"peerCerts,omitempty"`
|
||||
peerCerts []*x509.Certificate
|
||||
IP string `json:"ip,omitempty"`
|
||||
IP string `json:"ip,omitempty"`
|
||||
OAuthBearer *oauth2.OAuthBearerOptions `json:"oauth2,omitempty"`
|
||||
}
|
||||
|
||||
type AuthScriptOutput struct {
|
||||
|
|
|
|||
|
|
@ -64,6 +64,10 @@ const (
|
|||
BotTagName = "bot"
|
||||
// https://ircv3.net/specs/extensions/chathistory
|
||||
ChathistoryTargetsBatchType = "draft/chathistory-targets"
|
||||
|
||||
// draft/bearer defines this prefix namespace for authcids, enabling tunneling bearer tokens
|
||||
// in SASL PLAIN:
|
||||
BearerTokenPrefix = "*bearer*"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ package caps
|
|||
|
||||
const (
|
||||
// number of recognized capabilities:
|
||||
numCapabs = 34
|
||||
numCapabs = 35
|
||||
// length of the uint32 array that represents the bitset:
|
||||
bitsetLen = 2
|
||||
)
|
||||
|
|
@ -41,6 +41,10 @@ const (
|
|||
// https://github.com/ircv3/ircv3-specifications/pull/435
|
||||
AccountRegistration Capability = iota
|
||||
|
||||
// Bearer is the proposed IRCv3 capability named "draft/bearer":
|
||||
// https://gist.github.com/slingamn/4fabc7a3d5f335da7bb313a7f0648f37
|
||||
Bearer Capability = iota
|
||||
|
||||
// ChannelRename is the draft IRCv3 capability named "draft/channel-rename":
|
||||
// https://ircv3.net/specs/extensions/channel-rename
|
||||
ChannelRename Capability = iota
|
||||
|
|
@ -160,6 +164,7 @@ var (
|
|||
"cap-notify",
|
||||
"chghost",
|
||||
"draft/account-registration",
|
||||
"draft/bearer",
|
||||
"draft/channel-rename",
|
||||
"draft/chathistory",
|
||||
"draft/event-playback",
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/ergochat/irc-go/ircfmt"
|
||||
"github.com/ergochat/irc-go/ircmsg"
|
||||
"github.com/ergochat/irc-go/ircreader"
|
||||
"github.com/ergochat/irc-go/ircutils"
|
||||
"github.com/xdg-go/scram"
|
||||
|
||||
"github.com/ergochat/ergo/irc/caps"
|
||||
|
|
@ -28,6 +29,7 @@ import (
|
|||
"github.com/ergochat/ergo/irc/flatip"
|
||||
"github.com/ergochat/ergo/irc/history"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/oauth2"
|
||||
"github.com/ergochat/ergo/irc/sno"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
|
@ -119,12 +121,20 @@ type Client struct {
|
|||
|
||||
type saslStatus struct {
|
||||
mechanism string
|
||||
value string
|
||||
value ircutils.SASLBuffer
|
||||
scramConv *scram.ServerConversation
|
||||
oauthConv *oauth2.OAuthBearerServer
|
||||
}
|
||||
|
||||
func (s *saslStatus) Initialize() {
|
||||
s.value.Initialize(saslMaxResponseLength)
|
||||
}
|
||||
|
||||
func (s *saslStatus) Clear() {
|
||||
*s = saslStatus{}
|
||||
s.mechanism = ""
|
||||
s.value.Clear()
|
||||
s.scramConv = nil
|
||||
s.oauthConv = nil
|
||||
}
|
||||
|
||||
// what stage the client is at w.r.t. the PASS command:
|
||||
|
|
@ -362,6 +372,7 @@ func (server *Server) RunClient(conn IRCConn) {
|
|||
isTor: wConn.Tor,
|
||||
hideSTS: wConn.Tor || wConn.HideSTS,
|
||||
}
|
||||
session.sasl.Initialize()
|
||||
client.sessions = []*Session{session}
|
||||
|
||||
session.resetFakelag()
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ import (
|
|||
"github.com/ergochat/ergo/irc/logger"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/mysql"
|
||||
"github.com/ergochat/ergo/irc/oauth2"
|
||||
"github.com/ergochat/ergo/irc/passwd"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
|
@ -331,7 +332,9 @@ type AccountConfig struct {
|
|||
Multiclient MulticlientConfig
|
||||
Bouncer *MulticlientConfig // # handle old name for 'multiclient'
|
||||
VHosts VHostConfig
|
||||
AuthScript AuthScriptConfig `yaml:"auth-script"`
|
||||
AuthScript AuthScriptConfig `yaml:"auth-script"`
|
||||
OAuth2 oauth2.OAuth2BearerConfig `yaml:"oauth2"`
|
||||
JWTAuth jwt.JWTAuthConfig `yaml:"jwt-auth"`
|
||||
}
|
||||
|
||||
type ScriptConfig struct {
|
||||
|
|
@ -1391,15 +1394,44 @@ func LoadConfig(filename string) (config *Config, err error) {
|
|||
config.Accounts.VHosts.validRegexp = defaultValidVhostRegex
|
||||
}
|
||||
|
||||
saslCapValue := "PLAIN,EXTERNAL,SCRAM-SHA-256"
|
||||
if !config.Accounts.AdvertiseSCRAM {
|
||||
saslCapValue = "PLAIN,EXTERNAL"
|
||||
}
|
||||
config.Server.capValues[caps.SASL] = saslCapValue
|
||||
if !config.Accounts.AuthenticationEnabled {
|
||||
if config.Accounts.AuthenticationEnabled {
|
||||
saslCapValues := []string{"PLAIN", "EXTERNAL"}
|
||||
if config.Accounts.AdvertiseSCRAM {
|
||||
saslCapValues = append(saslCapValues, "SCRAM-SHA-256")
|
||||
}
|
||||
if config.Accounts.OAuth2.Enabled {
|
||||
saslCapValues = append(saslCapValues, "OAUTHBEARER")
|
||||
}
|
||||
config.Server.capValues[caps.SASL] = strings.Join(saslCapValues, ",")
|
||||
} else {
|
||||
config.Server.supportedCaps.Disable(caps.SASL)
|
||||
}
|
||||
|
||||
if err := config.Accounts.OAuth2.Postprocess(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := config.Accounts.JWTAuth.Postprocess(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.Accounts.OAuth2.Enabled && config.Accounts.OAuth2.AuthScript && !config.Accounts.AuthScript.Enabled {
|
||||
return nil, fmt.Errorf("oauth2 is enabled with auth-script, but no auth-script is enabled")
|
||||
}
|
||||
|
||||
var bearerCapValues []string
|
||||
if config.Accounts.OAuth2.Enabled {
|
||||
bearerCapValues = append(bearerCapValues, "oauth2")
|
||||
}
|
||||
if config.Accounts.JWTAuth.Enabled {
|
||||
bearerCapValues = append(bearerCapValues, "jwt")
|
||||
}
|
||||
if len(bearerCapValues) != 0 {
|
||||
config.Server.capValues[caps.Bearer] = strings.Join(bearerCapValues, ",")
|
||||
} else {
|
||||
config.Server.supportedCaps.Disable(caps.Bearer)
|
||||
}
|
||||
|
||||
if !config.Accounts.Registration.Enabled {
|
||||
config.Server.supportedCaps.Disable(caps.AccountRegistration)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ var (
|
|||
errValidEmailRequired = errors.New("A valid email address is required for account registration")
|
||||
errInvalidAccountRename = errors.New("Account renames can only change the casefolding of the account name")
|
||||
errNameReserved = errors.New(`Name reserved due to a prior registration`)
|
||||
errInvalidBearerTokenType = errors.New("invalid bearer token type")
|
||||
)
|
||||
|
||||
// String Errors
|
||||
|
|
|
|||
162
irc/handlers.go
162
irc/handlers.go
|
|
@ -8,7 +8,6 @@ package irc
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
|
@ -31,6 +30,7 @@ import (
|
|||
"github.com/ergochat/ergo/irc/history"
|
||||
"github.com/ergochat/ergo/irc/jwt"
|
||||
"github.com/ergochat/ergo/irc/modes"
|
||||
"github.com/ergochat/ergo/irc/oauth2"
|
||||
"github.com/ergochat/ergo/irc/sno"
|
||||
"github.com/ergochat/ergo/irc/utils"
|
||||
)
|
||||
|
|
@ -178,6 +178,10 @@ func acceptHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respo
|
|||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
saslMaxResponseLength = 8192 // implementation-defined sanity check, long enough for bearer tokens
|
||||
)
|
||||
|
||||
// AUTHENTICATE [<mechanism>|<data>|*]
|
||||
func authenticateHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool {
|
||||
session := rb.session
|
||||
|
|
@ -201,7 +205,7 @@ func authenticateHandler(server *Server, client *Client, msg ircmsg.Message, rb
|
|||
return false
|
||||
}
|
||||
|
||||
// start new sasl session
|
||||
// start new sasl session: parameter is the authentication mechanism
|
||||
if session.sasl.mechanism == "" {
|
||||
throttled, remainingTime := client.loginThrottle.Touch()
|
||||
if throttled {
|
||||
|
|
@ -213,6 +217,16 @@ func authenticateHandler(server *Server, client *Client, msg ircmsg.Message, rb
|
|||
mechanism := strings.ToUpper(msg.Params[0])
|
||||
_, mechanismIsEnabled := EnabledSaslMechanisms[mechanism]
|
||||
|
||||
// The spec says: "The AUTHENTICATE command MUST be used before registration
|
||||
// is complete and with the sasl capability enabled." Enforcing this universally
|
||||
// would simplify the implementation somewhat, but we've never enforced it before
|
||||
// and I don't want to break working clients that use PLAIN or EXTERNAL
|
||||
// and violate this MUST (e.g. by sending CAP END too early).
|
||||
if client.registered && !(mechanism == "PLAIN" || mechanism == "EXTERNAL") {
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, details.nick, client.t("SASL is only allowed before connection registration"))
|
||||
return false
|
||||
}
|
||||
|
||||
if mechanismIsEnabled {
|
||||
session.sasl.mechanism = mechanism
|
||||
if !config.Server.Compatibility.SendUnprefixedSasl {
|
||||
|
|
@ -230,46 +244,28 @@ func authenticateHandler(server *Server, client *Client, msg ircmsg.Message, rb
|
|||
return false
|
||||
}
|
||||
|
||||
// continue existing sasl session
|
||||
rawData := msg.Params[0]
|
||||
|
||||
// https://ircv3.net/specs/extensions/sasl-3.1:
|
||||
// "The response is encoded in Base64 (RFC 4648), then split to 400-byte chunks,
|
||||
// and each chunk is sent as a separate AUTHENTICATE command."
|
||||
saslMaxArgLength := 400
|
||||
if len(rawData) > saslMaxArgLength {
|
||||
// continue existing sasl session: parameter is a message chunk
|
||||
done, value, err := session.sasl.value.Add(msg.Params[0])
|
||||
if err == nil {
|
||||
if done {
|
||||
// call actual handler
|
||||
handler := EnabledSaslMechanisms[session.sasl.mechanism]
|
||||
return handler(server, client, session, value, rb)
|
||||
} else {
|
||||
return false // wait for continuation line
|
||||
}
|
||||
}
|
||||
// else: error handling
|
||||
switch err {
|
||||
case ircutils.ErrSASLTooLong:
|
||||
rb.Add(nil, server.name, ERR_SASLTOOLONG, details.nick, client.t("SASL message too long"))
|
||||
session.sasl.Clear()
|
||||
return false
|
||||
} else if len(rawData) == saslMaxArgLength {
|
||||
// allow 4 'continuation' lines before rejecting for length
|
||||
if len(session.sasl.value) >= saslMaxArgLength*4 {
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, details.nick, client.t("SASL authentication failed: Passphrase too long"))
|
||||
session.sasl.Clear()
|
||||
return false
|
||||
}
|
||||
session.sasl.value += rawData
|
||||
return false
|
||||
case ircutils.ErrSASLLimitExceeded:
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, details.nick, client.t("SASL authentication failed: Passphrase too long"))
|
||||
default:
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, details.nick, client.t("SASL authentication failed: Invalid b64 encoding"))
|
||||
}
|
||||
if rawData != "+" {
|
||||
session.sasl.value += rawData
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
if session.sasl.value != "+" {
|
||||
data, err = base64.StdEncoding.DecodeString(session.sasl.value)
|
||||
session.sasl.value = ""
|
||||
if err != nil {
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, details.nick, client.t("SASL authentication failed: Invalid b64 encoding"))
|
||||
session.sasl.Clear()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// call actual handler
|
||||
handler := EnabledSaslMechanisms[session.sasl.mechanism]
|
||||
return handler(server, client, session, data, rb)
|
||||
session.sasl.Clear()
|
||||
return false
|
||||
}
|
||||
|
||||
// AUTHENTICATE PLAIN
|
||||
|
|
@ -331,7 +327,7 @@ func authErrorToMessage(server *Server, err error) (msg string) {
|
|||
}
|
||||
|
||||
switch err {
|
||||
case errAccountDoesNotExist, errAccountUnverified, errAccountInvalidCredentials, errAuthzidAuthcidMismatch, errNickAccountMismatch, errAccountSuspended:
|
||||
case errAccountDoesNotExist, errAccountUnverified, errAccountInvalidCredentials, errAuthzidAuthcidMismatch, errNickAccountMismatch, errAccountSuspended, oauth2.ErrInvalidToken:
|
||||
return err.Error()
|
||||
default:
|
||||
// don't expose arbitrary error messages to the user
|
||||
|
|
@ -351,28 +347,18 @@ func authExternalHandler(server *Server, client *Client, session *Session, value
|
|||
|
||||
// EXTERNAL doesn't carry an authentication ID (this is determined from the
|
||||
// certificate), but does carry an optional authorization ID.
|
||||
var authzid string
|
||||
authzid := string(value)
|
||||
var deviceID string
|
||||
var err error
|
||||
if len(value) != 0 {
|
||||
authzid, err = CasefoldName(string(value))
|
||||
if err != nil {
|
||||
err = errAuthzidAuthcidMismatch
|
||||
}
|
||||
// see #843: strip the device ID for the benefit of clients that don't
|
||||
// distinguish user/ident from account name
|
||||
if strudelIndex := strings.IndexByte(authzid, '@'); strudelIndex != -1 {
|
||||
authzid, deviceID = authzid[:strudelIndex], authzid[strudelIndex+1:]
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// see #843: strip the device ID for the benefit of clients that don't
|
||||
// distinguish user/ident from account name
|
||||
if strudelIndex := strings.IndexByte(authzid, '@'); strudelIndex != -1 {
|
||||
var deviceID string
|
||||
authzid, deviceID = authzid[:strudelIndex], authzid[strudelIndex+1:]
|
||||
if !client.registered {
|
||||
rb.session.deviceID = deviceID
|
||||
}
|
||||
}
|
||||
err = server.accounts.AuthenticateByCertificate(client, rb.session.certfp, rb.session.peerCerts, authzid)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
sendAuthErrorResponse(client, rb, err)
|
||||
return false
|
||||
|
|
@ -381,6 +367,9 @@ func authExternalHandler(server *Server, client *Client, session *Session, value
|
|||
}
|
||||
|
||||
sendSuccessfulAccountAuth(nil, client, rb, true)
|
||||
if !client.registered && deviceID != "" {
|
||||
rb.session.deviceID = deviceID
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -418,9 +407,8 @@ func authScramHandler(server *Server, client *Client, session *Session, value []
|
|||
account, err := server.accounts.LoadAccount(authcid)
|
||||
if err == nil {
|
||||
server.accounts.Login(client, account)
|
||||
if fixupNickEqualsAccount(client, rb, server.Config(), "") {
|
||||
sendSuccessfulAccountAuth(nil, client, rb, true)
|
||||
}
|
||||
// fixupNickEqualsAccount is not needed for unregistered clients
|
||||
sendSuccessfulAccountAuth(nil, client, rb, true)
|
||||
} else {
|
||||
server.logger.Error("internal", "SCRAM succeeded but couldn't load account", authcid, err.Error())
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, client.nick, client.t("SASL authentication failed"))
|
||||
|
|
@ -433,7 +421,7 @@ func authScramHandler(server *Server, client *Client, session *Session, value []
|
|||
|
||||
response, err := session.sasl.scramConv.Step(string(value))
|
||||
if err == nil {
|
||||
rb.Add(nil, server.name, "AUTHENTICATE", base64.StdEncoding.EncodeToString([]byte(response)))
|
||||
sendSASLChallenge(server, rb, []byte(response))
|
||||
} else {
|
||||
continueAuth = false
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, client.Nick(), err.Error())
|
||||
|
|
@ -443,6 +431,58 @@ func authScramHandler(server *Server, client *Client, session *Session, value []
|
|||
return false
|
||||
}
|
||||
|
||||
// AUTHENTICATE OAUTHBEARER
|
||||
func authOauthBearerHandler(server *Server, client *Client, session *Session, value []byte, rb *ResponseBuffer) bool {
|
||||
if !server.Config().Accounts.OAuth2.Enabled {
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, client.Nick(), "SASL authentication failed: mechanism not enabled")
|
||||
return false
|
||||
}
|
||||
|
||||
if session.sasl.oauthConv == nil {
|
||||
session.sasl.oauthConv = oauth2.NewOAuthBearerServer(
|
||||
func(opts oauth2.OAuthBearerOptions) *oauth2.OAuthBearerError {
|
||||
err := server.accounts.AuthenticateByOAuthBearer(client, opts)
|
||||
switch err {
|
||||
case nil:
|
||||
return nil
|
||||
case oauth2.ErrInvalidToken:
|
||||
return &oauth2.OAuthBearerError{Status: "invalid_token", Schemes: "bearer"}
|
||||
case errFeatureDisabled:
|
||||
return &oauth2.OAuthBearerError{Status: "invalid_request", Schemes: "bearer"}
|
||||
default:
|
||||
// this is probably a misconfiguration or infrastructure error so we should log it
|
||||
server.logger.Error("internal", "failed to validate OAUTHBEARER token", err.Error())
|
||||
// tell the client it was their fault even though it probably wasn't:
|
||||
return &oauth2.OAuthBearerError{Status: "invalid_request", Schemes: "bearer"}
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
challenge, done, err := session.sasl.oauthConv.Next(value)
|
||||
if done {
|
||||
if err == nil {
|
||||
sendSuccessfulAccountAuth(nil, client, rb, true)
|
||||
} else {
|
||||
rb.Add(nil, server.name, ERR_SASLFAIL, client.Nick(), ircutils.SanitizeText(err.Error(), 350))
|
||||
}
|
||||
session.sasl.Clear()
|
||||
} else {
|
||||
// ignore `err`, we need to relay the challenge (which may contain a JSON-encoded error)
|
||||
// to the client
|
||||
sendSASLChallenge(server, rb, challenge)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// helper to b64 a sasl response and chunk it into 400-byte lines
|
||||
// as per https://ircv3.net/specs/extensions/sasl-3.1
|
||||
func sendSASLChallenge(server *Server, rb *ResponseBuffer, challenge []byte) {
|
||||
for _, chunk := range ircutils.EncodeSASLResponse(challenge) {
|
||||
rb.Add(nil, server.name, "AUTHENTICATE", chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// AWAY [<message>]
|
||||
func awayHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool {
|
||||
// #1996: `AWAY :` is treated the same as `AWAY`
|
||||
|
|
|
|||
157
irc/jwt/bearer.go
Normal file
157
irc/jwt/bearer.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
// Copyright (c) 2024 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||
// released under the MIT license
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthDisabled = fmt.Errorf("JWT authentication is disabled")
|
||||
ErrNoValidAccountClaim = fmt.Errorf("JWT token did not contain an acceptable account name claim")
|
||||
)
|
||||
|
||||
// JWTAuthConfig is the config for Ergo to accept JWTs via draft/bearer
|
||||
type JWTAuthConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Autocreate bool `yaml:"autocreate"`
|
||||
Tokens []JWTAuthTokenConfig `yaml:"tokens"`
|
||||
}
|
||||
|
||||
type JWTAuthTokenConfig struct {
|
||||
Algorithm string `yaml:"algorithm"`
|
||||
KeyString string `yaml:"key"`
|
||||
KeyFile string `yaml:"key-file"`
|
||||
key any
|
||||
parser *jwt.Parser
|
||||
AccountClaims []string `yaml:"account-claims"`
|
||||
StripDomain string `yaml:"strip-domain"`
|
||||
}
|
||||
|
||||
func (j *JWTAuthConfig) Postprocess() error {
|
||||
if !j.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(j.Tokens) == 0 {
|
||||
return fmt.Errorf("JWT authentication enabled, but no valid tokens defined")
|
||||
}
|
||||
|
||||
for i := range j.Tokens {
|
||||
if err := j.Tokens[i].Postprocess(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JWTAuthTokenConfig) Postprocess() error {
|
||||
keyBytes, err := j.keyBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
j.Algorithm = strings.ToLower(j.Algorithm)
|
||||
|
||||
var methods []string
|
||||
switch j.Algorithm {
|
||||
case "hmac":
|
||||
j.key = keyBytes
|
||||
methods = []string{"HS256", "HS384", "HS512"}
|
||||
case "rsa":
|
||||
rsaKey, err := jwt.ParseRSAPublicKeyFromPEM(keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
j.key = rsaKey
|
||||
methods = []string{"RS256", "RS384", "RS512"}
|
||||
case "eddsa":
|
||||
eddsaKey, err := jwt.ParseEdPublicKeyFromPEM(keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
j.key = eddsaKey
|
||||
methods = []string{"EdDSA"}
|
||||
default:
|
||||
return fmt.Errorf("invalid jwt algorithm: %s", j.Algorithm)
|
||||
}
|
||||
j.parser = jwt.NewParser(jwt.WithValidMethods(methods))
|
||||
|
||||
if len(j.AccountClaims) == 0 {
|
||||
return fmt.Errorf("JWT auth enabled, but no account-claims specified")
|
||||
}
|
||||
|
||||
j.StripDomain = strings.ToLower(j.StripDomain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JWTAuthConfig) Validate(t string) (accountName string, err error) {
|
||||
if !j.Enabled || len(j.Tokens) == 0 {
|
||||
return "", ErrAuthDisabled
|
||||
}
|
||||
|
||||
for i := range j.Tokens {
|
||||
accountName, err = j.Tokens[i].Validate(t)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (j *JWTAuthTokenConfig) keyBytes() (result []byte, err error) {
|
||||
if j.KeyFile != "" {
|
||||
o, err := os.Open(j.KeyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return io.ReadAll(o)
|
||||
}
|
||||
if j.KeyString != "" {
|
||||
return []byte(j.KeyString), nil
|
||||
}
|
||||
return nil, fmt.Errorf("JWT auth enabled, but no JWT key specified")
|
||||
}
|
||||
|
||||
// implements jwt.Keyfunc
|
||||
func (j *JWTAuthTokenConfig) keyFunc(_ *jwt.Token) (interface{}, error) {
|
||||
return j.key, nil
|
||||
}
|
||||
|
||||
func (j *JWTAuthTokenConfig) Validate(t string) (accountName string, err error) {
|
||||
token, err := j.parser.Parse(t, j.keyFunc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
// impossible with Parse (as opposed to ParseWithClaims)
|
||||
return "", fmt.Errorf("unexpected type from parsed token claims: %T", claims)
|
||||
}
|
||||
|
||||
for _, c := range j.AccountClaims {
|
||||
if v, ok := claims[c]; ok {
|
||||
if vstr, ok := v.(string); ok {
|
||||
// validate and strip email addresses:
|
||||
if idx := strings.IndexByte(vstr, '@'); idx != -1 {
|
||||
suffix := vstr[idx+1:]
|
||||
vstr = vstr[:idx]
|
||||
if strings.ToLower(suffix) != j.StripDomain {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return vstr, nil // success
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", ErrNoValidAccountClaim
|
||||
}
|
||||
143
irc/jwt/bearer_test.go
Normal file
143
irc/jwt/bearer_test.go
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
rsaTestPubKey = `-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwhcCcXrfR/GmoPKxBi0H
|
||||
cUl2pUl4acq2m3abFtMMoYTydJdEhgYWfsXuragyEIVkJU1ZnrgedW0QJUcANRGO
|
||||
hP/B+MjBevDNsRXQECfhyjfzhz6KWZb4i7C2oImJuAjq/F4qGLdEGQDBpAzof8qv
|
||||
9Zt5iN3GXY/EQtQVMFyR/7BPcbPLbHlOtzZ6tVEioXuUxQoai7x3Kc0jIcPWuyGa
|
||||
Q04IvsgdaWO6oH4fhPfyVsmX37rYUn79zcqPHS4ieWM1KN9qc7W+/UJIeiwAStpJ
|
||||
8gv+OSMrijRZGgQGCeOO5U59GGJC4mqUczB+JFvrlAIv0rggNpl+qalngosNxukB
|
||||
uQIDAQAB
|
||||
-----END PUBLIC KEY-----`
|
||||
|
||||
rsaTestPrivKey = `-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDCFwJxet9H8aag
|
||||
8rEGLQdxSXalSXhpyrabdpsW0wyhhPJ0l0SGBhZ+xe6tqDIQhWQlTVmeuB51bRAl
|
||||
RwA1EY6E/8H4yMF68M2xFdAQJ+HKN/OHPopZlviLsLagiYm4COr8XioYt0QZAMGk
|
||||
DOh/yq/1m3mI3cZdj8RC1BUwXJH/sE9xs8tseU63Nnq1USKhe5TFChqLvHcpzSMh
|
||||
w9a7IZpDTgi+yB1pY7qgfh+E9/JWyZffuthSfv3Nyo8dLiJ5YzUo32pztb79Qkh6
|
||||
LABK2knyC/45IyuKNFkaBAYJ447lTn0YYkLiapRzMH4kW+uUAi/SuCA2mX6pqWeC
|
||||
iw3G6QG5AgMBAAECggEARaAnejoP2ykvE1G8e3Cv2M33x/eBQMI9m6uCmz9+qnqc
|
||||
14JkTIfmjffHVXie7RpNAKys16lJE+rZ/eVoh6EStVdiaDLsZYP45evjRcho0Tgd
|
||||
Hokq7FSiOMpd2V09kE1yrrHA/DjSLv38eTNAPIejc8IgaR7VyD6Is0iNiVnL7iLa
|
||||
mj1zB6+dSeQ5ICYkrihb1gA+SvECsjLZ/5XESXEdHJvxhC0vLAdHmdQf3BPPlrGg
|
||||
VHondxL5gt6MFykpOxTFA6f5JkSefhUR/2OcCDpMs6a5GUytjl3rA3aGT6v3CbnR
|
||||
ykD6PzyC20EUADQYF2pmJfzbxyRqfNdbSJwQv5QQYQKBgQD4rFdvgZC97L7WhZ5T
|
||||
axW8hRW2dH24GIqFT4ZnCg0suyMNshyGvDMuBfGvokN/yACmvsdE0/f57esar+ye
|
||||
l9RC+CzGUch08Ke5WdqwACOCNDpx0kJcXKTuLIgkvthdla/oAQQ9T7OgEwDrvaR0
|
||||
m8s/Z7Hb3hLD3xdOt6Xjrv/6xQKBgQDHzvbcIkhmWdvaPDT9NEu7psR/fxF5UjqU
|
||||
Cca/bfHhySRQs3A1CF57pfwpUqAcSivNf7O+3NI62AKoyMDYv0ek2h6hGk6g5GJ1
|
||||
SuXYfjcbkL6SWNV0InsgmzCjvxhyms83xZq7uMClEBvkiKVMdt6zFkwW9eRKtUuZ
|
||||
pzVK5RfqZQKBgF5SME/xGw+O7su7ntQROAtrh1LPWKgtVs093sLSgzDGQoN9XWiV
|
||||
lewNASEXMPcUy3pzvm2S4OoBnj1fISb+e9py+7i1aI1CgrvBIzvCsbU/TjPCBr21
|
||||
vjFA3trhMHw+vJwJVqxSwNUkoCLKqcg5F5yTHllBIGj/A34uFlQIGrvpAoGAextm
|
||||
d+1bhExbLBQqZdOh0cWHjjKBVqm2U93OKcYY4Q9oI5zbRqGYbUCwo9k3sxZz9JJ4
|
||||
8eDmWsEaqlm+kA0SnFyTwJkP1wvAKhpykTf6xi4hbNP0+DACgu17Q3iLHJmLkQZc
|
||||
Nss3TrwlI2KZzgnzXo4fZYotFWasZMhkCngqiw0CgYEAmz2D70RYEauUNE1+zLhS
|
||||
6Ox5+PF/8Z0rZOlTghMTfqYcDJa+qQe9pJp7RPgilsgemqo0XtgLKz3ATE5FmMa4
|
||||
HRRGXPkMNu6Hzz4Yk4eM/yJqckoEc8azV25myqQ+7QXTwZEvxVbtUWZtxfImGwq+
|
||||
s/uzBKNwWf9UPTeIt+4JScg=
|
||||
-----END PRIVATE KEY-----`
|
||||
)
|
||||
|
||||
func TestJWTBearerAuth(t *testing.T) {
|
||||
j := JWTAuthConfig{
|
||||
Enabled: true,
|
||||
Tokens: []JWTAuthTokenConfig{
|
||||
{
|
||||
Algorithm: "rsa",
|
||||
KeyString: rsaTestPubKey,
|
||||
AccountClaims: []string{"preferred_username", "email"},
|
||||
StripDomain: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := j.Postprocess(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// fixed test vector signed with the RSA privkey:
|
||||
token := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJwcmVmZXJyZWRfdXNlcm5hbWUiOiJzbGluZ2FtbiJ9.caPZw2Dl4KZN-SErD5-WZB_lPPveHXaMCoUHxNebb94G9w3VaWDIRdngVU99JKx5nE_yRtpewkHHvXsQnNA_M63GBXGK7afXB8e-kV33QF3v9pXALMP5SzRwMgokyxas0RgHu4e4L0d7dn9o_nkdXp34GX3Pn1MVkUGBH6GdlbOdDHrs04pPQ0Qj-O2U0AIpnZq-X_GQs9ECJo4TlPKWR7Jlq5l9bS0dBnohea4FuqJr232je-dlRVkbCa7nrnFmsIsezsgA3Jb_j9Zu_iv460t_d2eaytbVp9P-DOVfzUfkBsKs-81URQEnTjW6ut445AJz2pxjX92X0GdmORpAkQ"
|
||||
accountName, err := j.Validate(token)
|
||||
if err != nil {
|
||||
t.Errorf("could not validate valid token: %v", err)
|
||||
}
|
||||
if accountName != "slingamn" {
|
||||
t.Errorf("incorrect account name for token: `%s`", accountName)
|
||||
}
|
||||
|
||||
// programmatically sign a new token, validate it
|
||||
privKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(rsaTestPrivKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jTok := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(map[string]any{"preferred_username": "slingamn"}))
|
||||
token, err = jTok.SignedString(privKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountName, err = j.Validate(token)
|
||||
if err != nil {
|
||||
t.Errorf("could not validate valid token: %v", err)
|
||||
}
|
||||
if accountName != "slingamn" {
|
||||
t.Errorf("incorrect account name for token: `%s`", accountName)
|
||||
}
|
||||
|
||||
// test expiration
|
||||
jTok = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(map[string]any{"preferred_username": "slingamn", "exp": 1675740865}))
|
||||
token, err = jTok.SignedString(privKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountName, err = j.Validate(token)
|
||||
if err == nil {
|
||||
t.Errorf("validated expired token")
|
||||
}
|
||||
|
||||
// test for the infamous algorithm confusion bug
|
||||
jTok = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(map[string]any{"preferred_username": "slingamn"}))
|
||||
token, err = jTok.SignedString([]byte(rsaTestPubKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountName, err = j.Validate(token)
|
||||
if err == nil {
|
||||
t.Errorf("validated HS256 token despite RSA being required")
|
||||
}
|
||||
|
||||
// test no valid claims
|
||||
jTok = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(map[string]any{"sub": "slingamn"}))
|
||||
token, err = jTok.SignedString(privKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
accountName, err = j.Validate(token)
|
||||
if err != ErrNoValidAccountClaim {
|
||||
t.Errorf("expected ErrNoValidAccountClaim, got: %v", err)
|
||||
}
|
||||
|
||||
// test email addresses
|
||||
jTok = jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(map[string]any{"email": "Slingamn@example.com"}))
|
||||
token, err = jTok.SignedString(privKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
accountName, err = j.Validate(token)
|
||||
if err != nil {
|
||||
t.Errorf("could not validate valid token: %v", err)
|
||||
}
|
||||
if accountName != "Slingamn" {
|
||||
t.Errorf("incorrect account name for token: `%s`", accountName)
|
||||
}
|
||||
}
|
||||
|
|
@ -6,18 +6,15 @@ package jwt
|
|||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoKeys = errors.New("No signing keys are enabled")
|
||||
ErrNoKeys = errors.New("No EXTJWT signing keys are enabled")
|
||||
)
|
||||
|
||||
type MapClaims jwt.MapClaims
|
||||
|
|
@ -38,22 +35,10 @@ func (t *JwtServiceConfig) Postprocess() (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d, _ := pem.Decode(keyBytes)
|
||||
t.rsaPrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.rsaPrivateKey, err = x509.ParsePKCS1PrivateKey(d.Bytes)
|
||||
if err != nil {
|
||||
privateKey, err := x509.ParsePKCS8PrivateKey(d.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey); ok {
|
||||
t.rsaPrivateKey = rsaPrivateKey
|
||||
} else {
|
||||
return fmt.Errorf("Non-RSA key type for extjwt: %T", privateKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
108
irc/oauth2/oauth2.go
Normal file
108
irc/oauth2/oauth2.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright 2022-2023 Simon Ser <contact@emersion.fr>
|
||||
// Derived from https://git.sr.ht/~emersion/soju/tree/36d6cb19a4f90d217d55afb0b15318321baaad09/item/auth/oauth2.go
|
||||
// Originally released under the AGPLv3, relicensed to the Ergo project under the MIT license
|
||||
// Modifications copyright 2024 Shivaram Lingamneni <slingamn@cs.stanford.edu>
|
||||
// Released under the MIT license
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthDisabled = fmt.Errorf("OAuth 2.0 authentication is disabled")
|
||||
|
||||
// all cases where the infrastructure is working correctly, but we determined
|
||||
// that the user supplied an invalid token
|
||||
ErrInvalidToken = fmt.Errorf("OAuth 2.0 bearer token invalid")
|
||||
)
|
||||
|
||||
type OAuth2BearerConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Autocreate bool `yaml:"autocreate"`
|
||||
AuthScript bool `yaml:"auth-script"`
|
||||
IntrospectionURL string `yaml:"introspection-url"`
|
||||
IntrospectionTimeout time.Duration `yaml:"introspection-timeout"`
|
||||
// omit for `none`, required for `client_secret_basic`
|
||||
ClientID string `yaml:"client-id"`
|
||||
ClientSecret string `yaml:"client-secret"`
|
||||
}
|
||||
|
||||
func (o *OAuth2BearerConfig) Postprocess() error {
|
||||
if !o.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if o.IntrospectionTimeout == 0 {
|
||||
return fmt.Errorf("a nonzero oauthbearer introspection timeout is required (try 10s)")
|
||||
}
|
||||
|
||||
if _, err := url.Parse(o.IntrospectionURL); err != nil {
|
||||
return fmt.Errorf("invalid introspection-url: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OAuth2BearerConfig) Introspect(ctx context.Context, token string) (username string, err error) {
|
||||
if !o.Enabled {
|
||||
return "", ErrAuthDisabled
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, o.IntrospectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
reqValues := make(url.Values)
|
||||
reqValues.Set("token", token)
|
||||
|
||||
reqBody := strings.NewReader(reqValues.Encode())
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, o.IntrospectionURL, reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth 2.0 introspection request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
if o.ClientID != "" {
|
||||
req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send OAuth 2.0 introspection request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("OAuth 2.0 introspection error: %v", resp.Status)
|
||||
}
|
||||
|
||||
var data oauth2Introspection
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return "", fmt.Errorf("failed to decode OAuth 2.0 introspection response: %v", err)
|
||||
}
|
||||
|
||||
if !data.Active {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
if data.Username == "" {
|
||||
// We really need the username here, otherwise an OAuth 2.0 user can
|
||||
// impersonate any other user.
|
||||
return "", fmt.Errorf("missing username in OAuth 2.0 introspection response")
|
||||
}
|
||||
|
||||
return data.Username, nil
|
||||
}
|
||||
|
||||
type oauth2Introspection struct {
|
||||
Active bool `json:"active"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
172
irc/oauth2/sasl.go
Normal file
172
irc/oauth2/sasl.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
package oauth2
|
||||
|
||||
/*
|
||||
https://github.com/emersion/go-sasl/blob/e73c9f7bad438a9bf3f5b28e661b74d752ecafdd/oauthbearer.go
|
||||
|
||||
Copyright 2019-2022 Simon Ser, Frode Aannevik, Max Mazurov
|
||||
Released under the MIT license
|
||||
*/
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnexpectedClientResponse = errors.New("unexpected client response")
|
||||
)
|
||||
|
||||
// The OAUTHBEARER mechanism name.
|
||||
const OAuthBearer = "OAUTHBEARER"
|
||||
|
||||
type OAuthBearerError struct {
|
||||
Status string `json:"status"`
|
||||
Schemes string `json:"schemes"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OAuthBearerOptions struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Port int `json:"port,omitempty"`
|
||||
}
|
||||
|
||||
func (err *OAuthBearerError) Error() string {
|
||||
return fmt.Sprintf("OAUTHBEARER authentication error (%v)", err.Status)
|
||||
}
|
||||
|
||||
type OAuthBearerAuthenticator func(opts OAuthBearerOptions) *OAuthBearerError
|
||||
|
||||
type OAuthBearerServer struct {
|
||||
done bool
|
||||
failErr error
|
||||
authenticate OAuthBearerAuthenticator
|
||||
}
|
||||
|
||||
func (a *OAuthBearerServer) fail(descr string) ([]byte, bool, error) {
|
||||
blob, err := json.Marshal(OAuthBearerError{
|
||||
Status: "invalid_request",
|
||||
Schemes: "bearer",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err) // wtf
|
||||
}
|
||||
a.failErr = errors.New(descr)
|
||||
return blob, false, nil
|
||||
}
|
||||
|
||||
func (a *OAuthBearerServer) Next(response []byte) (challenge []byte, done bool, err error) {
|
||||
// Per RFC, we cannot just send an error, we need to return JSON-structured
|
||||
// value as a challenge and then after getting dummy response from the
|
||||
// client stop the exchange.
|
||||
if a.failErr != nil {
|
||||
// Server libraries (go-smtp, go-imap) will not call Next on
|
||||
// protocol-specific SASL cancel response ('*'). However, GS2 (and
|
||||
// indirectly OAUTHBEARER) defines a protocol-independent way to do so
|
||||
// using 0x01.
|
||||
if len(response) != 1 && response[0] != 0x01 {
|
||||
return nil, true, errors.New("unexpected response")
|
||||
}
|
||||
return nil, true, a.failErr
|
||||
}
|
||||
|
||||
if a.done {
|
||||
err = ErrUnexpectedClientResponse
|
||||
return
|
||||
}
|
||||
|
||||
// Generate empty challenge.
|
||||
if response == nil {
|
||||
return []byte{}, false, nil
|
||||
}
|
||||
|
||||
a.done = true
|
||||
|
||||
// Cut n,a=username,\x01host=...\x01auth=...
|
||||
// into
|
||||
// n
|
||||
// a=username
|
||||
// \x01host=...\x01auth=...\x01\x01
|
||||
parts := bytes.SplitN(response, []byte{','}, 3)
|
||||
if len(parts) != 3 {
|
||||
return a.fail("Invalid response")
|
||||
}
|
||||
flag := parts[0]
|
||||
authzid := parts[1]
|
||||
if !bytes.Equal(flag, []byte{'n'}) {
|
||||
return a.fail("Invalid response, missing 'n' in gs2-cb-flag")
|
||||
}
|
||||
opts := OAuthBearerOptions{}
|
||||
if len(authzid) > 0 {
|
||||
if !bytes.HasPrefix(authzid, []byte("a=")) {
|
||||
return a.fail("Invalid response, missing 'a=' in gs2-authzid")
|
||||
}
|
||||
opts.Username = string(bytes.TrimPrefix(authzid, []byte("a=")))
|
||||
}
|
||||
|
||||
// Cut \x01host=...\x01auth=...\x01\x01
|
||||
// into
|
||||
// *empty*
|
||||
// host=...
|
||||
// auth=...
|
||||
// *empty*
|
||||
//
|
||||
// Note that this code does not do a lot of checks to make sure the input
|
||||
// follows the exact format specified by RFC.
|
||||
params := bytes.Split(parts[2], []byte{0x01})
|
||||
for _, p := range params {
|
||||
// Skip empty fields (one at start and end).
|
||||
if len(p) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
pParts := bytes.SplitN(p, []byte{'='}, 2)
|
||||
if len(pParts) != 2 {
|
||||
return a.fail("Invalid response, missing '='")
|
||||
}
|
||||
|
||||
switch string(pParts[0]) {
|
||||
case "host":
|
||||
opts.Host = string(pParts[1])
|
||||
case "port":
|
||||
port, err := strconv.ParseUint(string(pParts[1]), 10, 16)
|
||||
if err != nil {
|
||||
return a.fail("Invalid response, malformed 'port' value")
|
||||
}
|
||||
opts.Port = int(port)
|
||||
case "auth":
|
||||
const prefix = "bearer "
|
||||
strValue := string(pParts[1])
|
||||
// Token type is case-insensitive.
|
||||
if !strings.HasPrefix(strings.ToLower(strValue), prefix) {
|
||||
return a.fail("Unsupported token type")
|
||||
}
|
||||
opts.Token = strValue[len(prefix):]
|
||||
default:
|
||||
return a.fail("Invalid response, unknown parameter: " + string(pParts[0]))
|
||||
}
|
||||
}
|
||||
|
||||
authzErr := a.authenticate(opts)
|
||||
if authzErr != nil {
|
||||
blob, err := json.Marshal(authzErr)
|
||||
if err != nil {
|
||||
panic(err) // wtf
|
||||
}
|
||||
a.failErr = authzErr
|
||||
return blob, false, nil
|
||||
}
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
func NewOAuthBearerServer(auth OAuthBearerAuthenticator) *OAuthBearerServer {
|
||||
return &OAuthBearerServer{
|
||||
authenticate: auth,
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue