1
0
Fork 0
forked from External/ergo

improve maintainability and license compliance

0. Maximum parity with upstream code
1. Added Apache-required modification notices
2. Added Apache license
This commit is contained in:
Shivaram Lingamneni 2020-02-11 16:09:43 -05:00
parent c13597f807
commit 0c2d8adeac
7 changed files with 551 additions and 250 deletions

View file

@ -30,12 +30,8 @@
package ldap
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"strings"
ldap "github.com/go-ldap/ldap/v3"
@ -43,38 +39,48 @@ import (
)
var (
ErrCouldNotFindUser = errors.New("No such user")
ErrUserNotInRequiredGroup = errors.New("User is not a member of any required groups")
ErrInvalidCredentials = errors.New("Invalid credentials")
)
func CheckLDAPPassphrase(config LDAPConfig, accountName, passphrase string, log *logger.Manager) (err error) {
// equivalent of Grafana's `Server`, but unexported
type serverConn struct {
Config *ServerConfig
Connection *ldap.Conn
log *logger.Manager
}
func CheckLDAPPassphrase(config ServerConfig, accountName, passphrase string, log *logger.Manager) (err error) {
defer func() {
if err != nil {
log.Debug("ldap", "failed passphrase check", err.Error())
}
}()
l, err := dial(&config)
server := serverConn{
Config: &config,
log: log,
}
err = server.Dial()
if err != nil {
return
}
defer l.Close()
defer server.Close()
l.SetTimeout(config.Timeout)
server.Connection.SetTimeout(config.Timeout)
passphraseChecked := false
if config.shouldSingleBind() {
if server.shouldSingleBind() {
log.Debug("ldap", "attempting single bind to", accountName)
err = l.Bind(config.singleBindDN(accountName), passphrase)
err = server.userBind(server.singleBindDN(accountName), passphrase)
passphraseChecked = (err == nil)
} else if config.shouldAdminBind() {
} else if server.shouldAdminBind() {
log.Debug("ldap", "attempting admin bind to", config.BindDN)
err = l.Bind(config.BindDN, config.BindPassword)
err = server.userBind(config.BindDN, config.BindPassword)
} else {
log.Debug("ldap", "attempting unauthenticated bind")
err = l.UnauthenticatedBind(config.BindDN)
err = server.Connection.UnauthenticatedBind(config.BindDN)
}
if err != nil {
@ -85,7 +91,7 @@ func CheckLDAPPassphrase(config LDAPConfig, accountName, passphrase string, log
return nil
}
users, err := lookupUsers(l, &config, accountName)
users, err := server.users([]string{accountName})
if err != nil {
log.Debug("ldap", "failed user lookup")
return err
@ -99,225 +105,46 @@ func CheckLDAPPassphrase(config LDAPConfig, accountName, passphrase string, log
log.Debug("ldap", "looked up user", user.DN)
err = validateGroupMembership(l, &config, user, log)
err = server.validateGroupMembership(user)
if err != nil {
return err
}
if !passphraseChecked {
// Authenticate user
log.Debug("ldap", "rebinding", user.DN)
err = l.Bind(user.DN, passphrase)
if err != nil {
log.Debug("ldap", "failed rebind", err.Error())
if ldapErr, ok := err.(*ldap.Error); ok {
if ldapErr.ResultCode == 49 {
return ErrInvalidCredentials
}
}
}
return err
err = server.userBind(user.DN, passphrase)
}
return nil
return err
}
func dial(config *LDAPConfig) (conn *ldap.Conn, err error) {
var certPool *x509.CertPool
if config.RootCACert != "" {
certPool = x509.NewCertPool()
for _, caCertFile := range strings.Split(config.RootCACert, " ") {
pem, err := ioutil.ReadFile(caCertFile)
if err != nil {
return nil, err
}
if !certPool.AppendCertsFromPEM(pem) {
return nil, errors.New("Failed to append CA certificate " + caCertFile)
}
}
func (server *serverConn) validateGroupMembership(user *ldap.Entry) (err error) {
if len(server.Config.RequireGroups) == 0 {
return
}
var clientCert tls.Certificate
if config.ClientCert != "" && config.ClientKey != "" {
clientCert, err = tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
if err != nil {
return
}
}
for _, host := range strings.Split(config.Host, " ") {
address := fmt.Sprintf("%s:%d", host, config.Port)
if config.UseSSL {
tlsCfg := &tls.Config{
InsecureSkipVerify: config.SkipTLSVerify,
ServerName: host,
RootCAs: certPool,
}
if len(clientCert.Certificate) > 0 {
tlsCfg.Certificates = append(tlsCfg.Certificates, clientCert)
}
if config.StartTLS {
conn, err = ldap.Dial("tcp", address)
if err == nil {
if err = conn.StartTLS(tlsCfg); err == nil {
return
}
}
} else {
conn, err = ldap.DialTLS("tcp", address, tlsCfg)
}
} else {
conn, err = ldap.Dial("tcp", address)
}
if err == nil {
return
}
var memberOf []string
memberOf, err = server.getMemberOf(user)
if err != nil {
server.log.Debug("ldap", "could not retrieve group memberships", err.Error())
return
}
return
}
func validateGroupMembership(conn *ldap.Conn, config *LDAPConfig, user *ldap.Entry, log *logger.Manager) (err error) {
if len(config.RequireGroups) != 0 {
var memberOf []string
memberOf, err = getMemberOf(conn, config, user)
if err != nil {
log.Debug("ldap", "could not retrieve group memberships", err.Error())
return
}
log.Debug("ldap", fmt.Sprintf("found group memberships: %v", memberOf))
foundGroup := false
for _, inGroup := range memberOf {
for _, acceptableGroup := range config.RequireGroups {
if inGroup == acceptableGroup {
foundGroup = true
break
}
}
if foundGroup {
server.log.Debug("ldap", fmt.Sprintf("found group memberships: %v", memberOf))
foundGroup := false
for _, inGroup := range memberOf {
for _, acceptableGroup := range server.Config.RequireGroups {
if inGroup == acceptableGroup {
foundGroup = true
break
}
}
if !foundGroup {
return ErrUserNotInRequiredGroup
}
}
return nil
}
func lookupUsers(conn *ldap.Conn, config *LDAPConfig, accountName string) (results []*ldap.Entry, err error) {
var result *ldap.SearchResult
for _, base := range config.SearchBaseDNs {
result, err = conn.Search(
getSearchRequest(config, base, accountName),
)
if err != nil {
return nil, err
} else if len(result.Entries) > 0 {
return result.Entries, nil
}
}
return nil, nil
}
// getSearchRequest returns LDAP search request for users
func getSearchRequest(
config *LDAPConfig,
base string,
accountName string,
) *ldap.SearchRequest {
var attributes []string
if config.MemberOfAttribute != "" {
attributes = []string{config.MemberOfAttribute}
}
query := strings.Replace(
config.SearchFilter,
"%s", ldap.EscapeFilter(accountName),
-1,
)
return &ldap.SearchRequest{
BaseDN: base,
Scope: ldap.ScopeWholeSubtree,
DerefAliases: ldap.NeverDerefAliases,
Attributes: attributes,
Filter: query,
}
}
// getMemberOf finds memberOf property or request it
func getMemberOf(conn *ldap.Conn, config *LDAPConfig, result *ldap.Entry) (
[]string, error,
) {
if config.GroupSearchFilter == "" {
memberOf := getArrayAttribute(config.MemberOfAttribute, result)
return memberOf, nil
}
memberOf, err := requestMemberOf(conn, config, result)
if err != nil {
return nil, err
}
return memberOf, nil
}
// requestMemberOf use this function when POSIX LDAP
// schema does not support memberOf, so it manually search the groups
func requestMemberOf(conn *ldap.Conn, config *LDAPConfig, entry *ldap.Entry) ([]string, error) {
var memberOf []string
for _, groupSearchBase := range config.GroupSearchBaseDNs {
var filterReplace string
if config.GroupSearchFilterUserAttribute == "" {
filterReplace = "cn"
} else {
filterReplace = getAttribute(
config.GroupSearchFilterUserAttribute,
entry,
)
}
filter := strings.Replace(
config.GroupSearchFilter, "%s",
ldap.EscapeFilter(filterReplace),
-1,
)
// support old way of reading settings
groupIDAttribute := config.MemberOfAttribute
// but prefer dn attribute if default settings are used
if groupIDAttribute == "" || groupIDAttribute == "memberOf" {
groupIDAttribute = "dn"
}
groupSearchReq := ldap.SearchRequest{
BaseDN: groupSearchBase,
Scope: ldap.ScopeWholeSubtree,
DerefAliases: ldap.NeverDerefAliases,
Attributes: []string{groupIDAttribute},
Filter: filter,
}
groupSearchResult, err := conn.Search(&groupSearchReq)
if err != nil {
return nil, err
}
if len(groupSearchResult.Entries) > 0 {
for _, group := range groupSearchResult.Entries {
memberOf = append(
memberOf,
getAttribute(groupIDAttribute, group),
)
}
if foundGroup {
break
}
}
return memberOf, nil
if foundGroup {
return nil
} else {
return ErrUserNotInRequiredGroup
}
}