aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--main.go135
1 files changed, 110 insertions, 25 deletions
diff --git a/main.go b/main.go
index aeeb421..92c0712 100644
--- a/main.go
+++ b/main.go
@@ -46,11 +46,27 @@ func consulToDN(pair *consul.KVPair) (string, string, []byte) {
return dn, "", nil
}
+func parseValue(value []byte) ([]string, error) {
+ val := []string{}
+ err := json.Unmarshal(value, &val)
+ if err == nil {
+ return val, nil
+ }
+
+ val2 := ""
+ err = json.Unmarshal(value, &val2)
+ if err == nil {
+ return []string{val2}, nil
+ }
+
+ return nil, fmt.Errorf("Not a string or list of strings: %s", value)
+}
+
func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) {
aggregator := map[string]Entry{}
for _, kv := range data {
- log.Printf("%s %s", kv.Key, string(kv.Value))
+ log.Printf("(parseConsulResult) %s %s", kv.Key, string(kv.Value))
dn, attr, val := consulToDN(kv)
if attr == "" || val == nil {
continue
@@ -58,8 +74,7 @@ func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) {
if _, exists := aggregator[dn]; !exists {
aggregator[dn] = Entry{}
}
- var value []string
- err := json.Unmarshal(val, &value)
+ value, err := parseValue(val)
if err != nil {
return nil, err
}
@@ -224,6 +239,29 @@ func (server *Server) addElements(dn string, attrs Entry) error {
return nil
}
+func (server *Server) getAttribute(dn string, attr string) ([]string, error) {
+ pair, _, err := server.kv.Get(dnToConsul(dn) + "/attribute=" + attr, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ if pair == nil {
+ return nil, nil
+ }
+
+ return parseValue(pair.Value)
+}
+
+func (server *Server) objectExists(dn string) (bool, error) {
+ prefix := dnToConsul(dn) + "/"
+
+ data, _, err := server.kv.List(prefix, nil)
+ if err != nil {
+ return false, err
+ }
+ return len(data) > 0, nil
+}
+
func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
state := s.(*State)
r := m.GetBindRequest()
@@ -240,28 +278,23 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda
func (server *Server) handleBindInternal(state *State, r *message.BindRequest) (int, error) {
- pair, _, err := server.kv.Get(dnToConsul(string(r.Name()))+"/attribute=userpassword", nil)
+ passwd, err := server.getAttribute(string(r.Name()), "userpassword")
if err != nil {
return ldap.LDAPResultOperationsError, err
}
- if pair == nil {
+ if passwd == nil {
return ldap.LDAPResultNoSuchObject, nil
}
- hash := ""
- err = json.Unmarshal(pair.Value, &hash)
- if err != nil {
- return ldap.LDAPResultOperationsError, err
- }
-
- valid := SSHAMatches(hash, []byte(r.AuthenticationSimple()))
- if valid {
- state.bindDn = string(r.Name())
- return ldap.LDAPResultSuccess, nil
- } else {
- return ldap.LDAPResultInvalidCredentials, nil
+ for _, hash := range passwd {
+ valid := SSHAMatches(hash, []byte(r.AuthenticationSimple()))
+ if valid {
+ state.bindDn = string(r.Name())
+ return ldap.LDAPResultSuccess, nil
+ }
}
+ return ldap.LDAPResultInvalidCredentials, nil
}
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
@@ -287,9 +320,25 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter,
log.Printf("Request TimeLimit=%d", r.TimeLimit().Int())
// TODO check authorizations
+ baseObject := dnToConsul(string(r.BaseObject()))
+ minimalBaseObject := dnToConsul(server.config.Suffix)
+
+ if len(baseObject) <= len(minimalBaseObject) {
+ if baseObject != minimalBaseObject[:len(baseObject)] {
+ return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf(
+ "Only handling search results under DN=%s",
+ server.config.Suffix)
+ }
+ baseObject = minimalBaseObject
+ } else {
+ if baseObject[:len(minimalBaseObject)] != minimalBaseObject {
+ return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf(
+ "Only handling search results under DN=%s",
+ server.config.Suffix)
+ }
+ }
- basePath := dnToConsul(string(r.BaseObject())) + "/"
- data, _, err := server.kv.List(basePath, nil)
+ data, _, err := server.kv.List(baseObject + "/", nil)
if err != nil {
return ldap.LDAPResultOperationsError, err
}
@@ -298,7 +347,7 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter,
if err != nil {
return ldap.LDAPResultOperationsError, err
}
- log.Printf("in %s: %#v", basePath, data)
+ log.Printf("in %s: %#v", baseObject + "/", data)
log.Printf("%#v", entries)
for dn, entry := range entries {
@@ -412,25 +461,42 @@ func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap
func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) {
dn := string(r.Entry())
- prefix := dnToConsul(dn) + "/"
-
- data, _, err := server.kv.List(prefix, nil)
+ exists, err := server.objectExists(dn)
if err != nil {
return ldap.LDAPResultOperationsError, err
}
- if len(data) > 0 {
+ if exists {
return ldap.LDAPResultEntryAlreadyExists, nil
}
// TODO check permissions
+ var members []string = nil
entry := Entry{}
for _, attribute := range r.Attributes() {
key := string(attribute.Type_())
+ if strings.EqualFold(key, "memberOf") {
+ return ldap.LDAPResultObjectClassViolation, fmt.Errorf(
+ "memberOf cannot be defined directly, membership must be specified in the group itself")
+ }
vals_str := []string{}
for _, val := range attribute.Vals() {
vals_str = append(vals_str, string(val))
}
+ if strings.EqualFold(key, "member") {
+ members = vals_str
+ for _, member := range members {
+ exists, err = server.objectExists(member)
+ if err != nil {
+ return ldap.LDAPResultOperationsError, err
+ }
+ if !exists {
+ return ldap.LDAPResultNoSuchObject, fmt.Errorf(
+ "Cannot add %s to members, it does not exist!",
+ member)
+ }
+ }
+ }
entry[key] = vals_str
}
@@ -439,6 +505,25 @@ func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (in
return ldap.LDAPResultOperationsError, err
}
+ if members != nil {
+ for _, member := range members {
+ memberGroups, err := server.getAttribute(member, "memberOf")
+ if err != nil {
+ return ldap.LDAPResultOperationsError, err
+ }
+ if memberGroups == nil {
+ memberGroups = []string{}
+ }
+
+ memberGroups = append(memberGroups, dn)
+ err = server.addElements(member, Entry{
+ "memberOf": memberGroups,
+ })
+ if err != nil {
+ return ldap.LDAPResultOperationsError, err
+ }
+ }
+ }
+
return ldap.LDAPResultSuccess, nil
}
-