aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go51
1 files changed, 33 insertions, 18 deletions
diff --git a/main.go b/main.go
index 92c0712..a05bda8 100644
--- a/main.go
+++ b/main.go
@@ -262,6 +262,23 @@ func (server *Server) objectExists(dn string) (bool, error) {
return len(data) > 0, nil
}
+func (server *Server) checkSuffix(dn string, allow_extend bool) (string, error) {
+ suffix := server.config.Suffix
+ if len(dn) < len(suffix) {
+ if dn != suffix[-len(dn):] || !allow_extend {
+ return suffix, fmt.Errorf(
+ "Only handling stuff under DN %s", suffix)
+ }
+ return suffix, nil
+ } else {
+ if dn[len(dn)-len(suffix):] != suffix {
+ return suffix, fmt.Errorf(
+ "Only handling stuff under DN %s", suffix)
+ }
+ return dn, nil
+ }
+}
+
func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
state := s.(*State)
r := m.GetBindRequest()
@@ -320,25 +337,14 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter,
log.Printf("Request TimeLimit=%d", r.TimeLimit().Int())
// TODO check authorizations
- baseObject := dnToConsul(string(r.BaseObject()))
- minimalBaseObject := dnToConsul(server.config.Suffix)
-
- if len(baseObject) <= len(minimalBaseObject) {
- if baseObject != minimalBaseObject[:len(baseObject)] {
- return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf(
- "Only handling search results under DN=%s",
- server.config.Suffix)
- }
- baseObject = minimalBaseObject
- } else {
- if baseObject[:len(minimalBaseObject)] != minimalBaseObject {
- return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf(
- "Only handling search results under DN=%s",
- server.config.Suffix)
- }
+
+ baseObject, err := server.checkSuffix(string(r.BaseObject()), true)
+ if err != nil {
+ return ldap.LDAPResultInvalidDNSyntax, err
}
+ basePath := dnToConsul(baseObject) + "/"
- data, _, err := server.kv.List(baseObject + "/", nil)
+ data, _, err := server.kv.List(basePath, nil)
if err != nil {
return ldap.LDAPResultOperationsError, err
}
@@ -347,7 +353,7 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter,
if err != nil {
return ldap.LDAPResultOperationsError, err
}
- log.Printf("in %s: %#v", baseObject + "/", data)
+ log.Printf("in %s: %#v", basePath, data)
log.Printf("%#v", entries)
for dn, entry := range entries {
@@ -461,6 +467,11 @@ func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap
func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) {
dn := string(r.Entry())
+ _, err := server.checkSuffix(dn, false)
+ if err != nil {
+ return ldap.LDAPResultInvalidDNSyntax, err
+ }
+
exists, err := server.objectExists(dn)
if err != nil {
return ldap.LDAPResultOperationsError, err
@@ -486,6 +497,10 @@ func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (in
if strings.EqualFold(key, "member") {
members = vals_str
for _, member := range members {
+ _, err := server.checkSuffix(member, false)
+ if err != nil {
+ return ldap.LDAPResultInvalidDNSyntax, err
+ }
exists, err = server.objectExists(member)
if err != nil {
return ldap.LDAPResultOperationsError, err