aboutsummaryrefslogblamecommitdiff
path: root/read.go
blob: 04106c59f6697e4ea7cdc2ddb44d9f4ab43380b0 (plain) (tree)



























































































































































































































                                                                                                                                        
package main

import (
	"fmt"
	"strings"

	ldap "./ldapserver"
	message "github.com/vjeantet/goldap/message"
)


// Compare request -------------------------

// handleCompare is the entry point for LDAP Compare requests.
//
// It extracts the compare request from m, delegates the actual work to
// handleCompareInternal and writes a single CompareResponse to w. The result
// code returned by the internal handler becomes the response code; when an
// error is returned as well, its text is attached as the diagnostic message.
func (server *Server) handleCompare(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	session := s.(*State)
	request := m.GetCompareRequest()

	resultCode, failure := server.handleCompareInternal(session, &request)

	response := ldap.NewResponse(resultCode)
	if failure != nil {
		response.SetDiagnosticMessage(failure.Error())
	}
	w.Write(message.CompareResponse(response))
}

// handleCompareInternal evaluates a Compare request and returns the LDAP
// result code to send back, plus an optional error describing a failure.
//
// The steps are, in order:
//  1. the entry DN is validated with checkSuffix (LDAPResultInvalidDNSyntax
//     on failure);
//  2. the ACL is asked whether the logged-in user may "read" the compared
//     attribute of that DN (LDAPResultInsufficientAccessRights if not);
//  3. the object must exist (LDAPResultNoSuchObject if it does not,
//     LDAPResultOperationsError if the lookup itself fails);
//  4. the attribute values are fetched and compared to the assertion value
//     with an exact string comparison, yielding LDAPResultCompareTrue or
//     LDAPResultCompareFalse.
func (server *Server) handleCompareInternal(state *State, r *message.CompareRequest) (int, error) {
	dn := string(r.Entry())
	attr := string(r.Ava().AttributeDesc())
	expected := string(r.Ava().AssertionValue())

	// Only the validation result matters here; the first value returned by
	// checkSuffix is discarded and the DN is used as sent by the client.
	_, err := server.checkSuffix(dn, false)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	// Check permissions.
	// Acl.Check is called as (login, action, target, attributes), the same
	// argument order used by the search handler in this file. The action and
	// the DN used to be passed in the opposite order here, which made the ACL
	// evaluate the DN as the action and "read" as the target.
	if !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
		return ldap.LDAPResultInsufficientAccessRights, nil
	}

	// Do query
	exists, err := server.objectExists(dn)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}
	if !exists {
		return ldap.LDAPResultNoSuchObject, fmt.Errorf("Not found: %s", dn)
	}

	values, err := server.getAttribute(dn, attr)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	// The assertion holds as soon as one stored value is exactly equal to
	// the expected value.
	for _, v := range values {
		if v == expected {
			return ldap.LDAPResultCompareTrue, nil
		}
	}

	return ldap.LDAPResultCompareFalse, nil
}


// Search request -------------------------

// handleSearch is the entry point for LDAP Search requests.
//
// It extracts the search request from m and hands it, together with the
// response writer, to handleSearchInternal, which writes the matching entries.
// Afterwards a SearchResultDone message carrying the returned result code is
// written to w; if an error was returned, its text is attached as the
// diagnostic message.
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	session := s.(*State)
	request := m.GetSearchRequest()

	resultCode, failure := server.handleSearchInternal(session, w, &request)

	done := ldap.NewResponse(resultCode)
	if failure != nil {
		done.SetDiagnosticMessage(failure.Error())
	}
	w.Write(message.SearchResultDone(done))
}

// handleSearchInternal executes a Search request and writes one
// SearchResultEntry to w for every entry that passes the filters below. It
// returns the LDAP result code for the final SearchResultDone message and an
// optional error describing a failure.
//
// Processing order:
//  1. the user must have "read" access on the requested base object
//     (LDAPResultInsufficientAccessRights otherwise);
//  2. the base object is validated by checkSuffix and converted to a key path
//     by dnToConsul (LDAPResultInvalidDNSyntax on failure);
//  3. everything stored under that path is listed from the KV store and parsed
//     into entries (LDAPResultOperationsError on failure);
//  4. each entry is kept only if it matches the request filter and the user
//     may read it; an unsupported filter aborts the search with
//     LDAPResultUnwillingToPerform;
//  5. for kept entries, attributes are returned only if they were requested
//     (or no attribute list was given) and the user may read them.
func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) {
	if DEBUG {
		server.logger.Printf("-- SEARCH REQUEST: --")
		server.logger.Printf("Request BaseDn=%s", r.BaseObject())
		server.logger.Printf("Request Filter=%s", r.Filter())
		server.logger.Printf("Request FilterString=%s", r.FilterString())
		server.logger.Printf("Request Attributes=%s", r.Attributes())
		server.logger.Printf("Request TimeLimit=%d", r.TimeLimit().Int())
	}

	// The access check is made on the base object exactly as sent by the
	// client, before it goes through checkSuffix.
	if !server.config.Acl.Check(&state.login, "read", string(r.BaseObject()), []string{}) {
		return ldap.LDAPResultInsufficientAccessRights, fmt.Errorf("Please specify a base object on which you have read rights")
	}

	baseObject, err := server.checkSuffix(string(r.BaseObject()), true)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}
	basePath, err := dnToConsul(baseObject)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	// List every key stored below the base path.
	prefix := basePath + "/"
	data, _, err := server.kv.List(prefix, nil)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	entries, err := parseConsulResult(data)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}
	if DEBUG {
		server.logger.Printf("in %s: %#v", prefix, data)
		server.logger.Printf("%#v", entries)
	}

	for dn, entry := range entries {
		matched, err := applyFilter(entry, r.Filter())
		if err != nil {
			return ldap.LDAPResultUnwillingToPerform, err
		}
		// Skip entries that do not match the requested filter, then those
		// the user is not allowed to read (the ACL is only consulted for
		// matching entries).
		if !matched || !server.config.Acl.Check(&state.login, "read", dn, []string{}) {
			continue
		}

		result := ldap.NewSearchResultEntry(dn)
		for attr, values := range entry {
			// An empty attribute list in the request selects every
			// attribute; otherwise the name must appear in the list
			// (compared case-insensitively).
			requestedAttrs := r.Attributes()
			selected := len(requestedAttrs) == 0
			for _, requested := range requestedAttrs {
				if strings.EqualFold(string(requested), attr) {
					selected = true
					break
				}
			}

			// Drop attributes that were not requested, then those the
			// user is not allowed to read.
			if !selected || !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
				continue
			}

			for _, v := range values {
				result.AddAttribute(message.AttributeDescription(attr),
					message.AttributeValue(v))
			}
		}
		w.Write(result)
	}

	return ldap.LDAPResultSuccess, nil
}

// applyFilter reports whether entry satisfies the given LDAP search filter.
//
// Supported filter kinds:
//   - FilterAnd: true when every sub-filter matches (true for an empty list);
//   - FilterOr: true when at least one sub-filter matches (false for an empty
//     list);
//   - FilterNot: negation of the wrapped filter;
//   - FilterPresent: true when the entry has the attribute (name compared
//     case-insensitively) with at least one value;
//   - FilterEqualityMatch: true when the attribute (name compared
//     case-insensitively) holds a value exactly equal to the assertion value.
//
// Any other filter kind yields an error. An error from a sub-filter is
// propagated immediately.
func applyFilter(entry Entry, filter message.Filter) (bool, error) {
	switch f := filter.(type) {
	case message.FilterAnd:
		for _, sub := range f {
			ok, err := applyFilter(entry, sub)
			if err != nil {
				return false, err
			}
			if !ok {
				return false, nil
			}
		}
		return true, nil

	case message.FilterOr:
		for _, sub := range f {
			ok, err := applyFilter(entry, sub)
			if err != nil {
				return false, err
			}
			if ok {
				return true, nil
			}
		}
		return false, nil

	case message.FilterNot:
		ok, err := applyFilter(entry, f.Filter)
		if err != nil {
			return false, err
		}
		return !ok, nil

	case message.FilterPresent:
		wanted := string(f)
		// Attribute names are matched case-insensitively; the first
		// matching name decides the outcome.
		for name, values := range entry {
			if strings.EqualFold(wanted, name) {
				return len(values) > 0, nil
			}
		}
		return false, nil

	case message.FilterEqualityMatch:
		wantedAttr := string(f.AttributeDesc())
		wantedValue := string(f.AssertionValue())
		// Attribute names are matched case-insensitively, values exactly.
		// Only the first attribute whose name matches is examined.
		for name, values := range entry {
			if !strings.EqualFold(name, wantedAttr) {
				continue
			}
			for _, v := range values {
				if v == wantedValue {
					return true, nil
				}
			}
			return false, nil
		}
		return false, nil

	default:
		return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter)
	}
}