diff options
-rw-r--r-- | main.go | 51 |
1 files changed, 33 insertions, 18 deletions
@@ -262,6 +262,23 @@ func (server *Server) objectExists(dn string) (bool, error) { return len(data) > 0, nil } +func (server *Server) checkSuffix(dn string, allow_extend bool) (string, error) { + suffix := server.config.Suffix + if len(dn) < len(suffix) { + if dn != suffix[-len(dn):] || !allow_extend { + return suffix, fmt.Errorf( + "Only handling stuff under DN %s", suffix) + } + return suffix, nil + } else { + if dn[len(dn)-len(suffix):] != suffix { + return suffix, fmt.Errorf( + "Only handling stuff under DN %s", suffix) + } + return dn, nil + } +} + func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { state := s.(*State) r := m.GetBindRequest() @@ -320,25 +337,14 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, log.Printf("Request TimeLimit=%d", r.TimeLimit().Int()) // TODO check authorizations - baseObject := dnToConsul(string(r.BaseObject())) - minimalBaseObject := dnToConsul(server.config.Suffix) - - if len(baseObject) <= len(minimalBaseObject) { - if baseObject != minimalBaseObject[:len(baseObject)] { - return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf( - "Only handling search results under DN=%s", - server.config.Suffix) - } - baseObject = minimalBaseObject - } else { - if baseObject[:len(minimalBaseObject)] != minimalBaseObject { - return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf( - "Only handling search results under DN=%s", - server.config.Suffix) - } + + baseObject, err := server.checkSuffix(string(r.BaseObject()), true) + if err != nil { + return ldap.LDAPResultInvalidDNSyntax, err } + basePath := dnToConsul(baseObject) + "/" - data, _, err := server.kv.List(baseObject + "/", nil) + data, _, err := server.kv.List(basePath, nil) if err != nil { return ldap.LDAPResultOperationsError, err } @@ -347,7 +353,7 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, if err != nil { return ldap.LDAPResultOperationsError, err } - log.Printf("in %s: %#v", baseObject + "/", data) + log.Printf("in %s: %#v", basePath, data) log.Printf("%#v", entries) for dn, entry := range entries { @@ -461,6 +467,11 @@ func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) { dn := string(r.Entry()) + _, err := server.checkSuffix(dn, false) + if err != nil { + return ldap.LDAPResultInvalidDNSyntax, err + } + exists, err := server.objectExists(dn) if err != nil { return ldap.LDAPResultOperationsError, err @@ -486,6 +497,10 @@ func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (in if strings.EqualFold(key, "member") { members = vals_str for _, member := range members { + _, err := server.checkSuffix(member, false) + if err != nil { + return ldap.LDAPResultInvalidDNSyntax, err + } exists, err = server.objectExists(member) if err != nil { return ldap.LDAPResultOperationsError, err |