aboutsummaryrefslogblamecommitdiff
path: root/read.go
blob: 7887c934bb66f41a1af30a495f4e4a2cef361228 (plain) (tree)
1
2
3
4
5
6
7
8
9





                 

                                
                                                  

 







                                                                              

                                                                                         




                                                                             
                                                               


























                                                                              















                                                                                                   


                                                    
                                                           























                                                                                   
                                                  






                                                              











                                                                                              


                                                                            



                                                                                                                        
 




                                                                       
                                                     
                                                             



                                                                         
 
                                                                                   


                                                                                                                                        

                                                              



                                                          




                                                              
 
                                                                  







                                                          
 
                                                          
                                            

                                        





                                                                         
                                                             


                                        



















                                                                                         



                                                                                                                   




































































                                                                                                  
                                                                                










                                                                                      
package main

import (
	"fmt"
	"strings"

	ldap "bottin/ldapserver"

	message "github.com/lor00x/goldap/message"
)

// Generic read utility functions ----------

// getAttribute returns all values stored for attribute attr on the object
// named by dn.
//
// The attribute name is matched case-insensitively: every attribute key of the
// object is listed, and each key equal to attr (ignoring case) contributes the
// values decoded by parseValue. The returned slice is non-nil even when no key
// matches. Errors from dnToConsul, the key-value listing or parseValue are
// returned unchanged.
func (server *Server) getAttribute(dn string, attr string) ([]string, error) {
	path, err := dnToConsul(dn)
	if err != nil {
		return nil, err
	}
	attrPrefix := path + "/attribute="

	// An exact key lookup could miss the attribute because it may be stored
	// with a different case than attr, so fetch every attribute of the object.
	pairs, _, err := server.kv.List(attrPrefix, &server.readOpts)
	if err != nil {
		return nil, err
	}

	wantedKey := attrPrefix + attr
	collected := make([]string, 0)
	for _, p := range pairs {
		if !strings.EqualFold(p.Key, wantedKey) {
			continue
		}
		decoded, err := parseValue(p.Value)
		if err != nil {
			return nil, err
		}
		collected = append(collected, decoded...)
	}
	return collected, nil
}

// objectExists reports whether the object named by dn has at least one
// attribute key stored under the prefix that dnToConsul derives from dn.
// Errors from dnToConsul or from the key-value listing are returned with false.
func (server *Server) objectExists(dn string) (bool, error) {
	path, err := dnToConsul(dn)
	if err != nil {
		return false, err
	}

	attrs, _, listErr := server.kv.List(path+"/attribute=", &server.readOpts)
	if listErr != nil {
		return false, listErr
	}
	found := len(attrs) != 0
	return found, nil
}

// Compare request -------------------------

// handleCompare serves an LDAP Compare request. The actual evaluation is done
// by handleCompareInternal; this wrapper turns its result code into a
// CompareResponse, attaches the error text (if any) as diagnostic message and
// writes the response to w.
func (server *Server) handleCompare(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	st := s.(*State)
	req := m.GetCompareRequest()

	resultCode, cmpErr := server.handleCompareInternal(st, &req)

	response := ldap.NewResponse(resultCode)
	if cmpErr != nil {
		response.SetDiagnosticMessage(cmpErr.Error())
	}
	w.Write(message.CompareResponse(response))
}

// handleCompareInternal evaluates an LDAP Compare request for the user logged
// in on state. It returns the LDAP result code to send back, plus an optional
// error whose text becomes the response's diagnostic message.
//
// Outcomes:
//   - LDAPResultInvalidDNSyntax if the target DN is rejected by checkDN;
//   - LDAPResultInsufficientAccessRights if the user may not "read" the
//     compared attribute on the target DN;
//   - LDAPResultNoSuchObject if the object has no stored attributes;
//   - LDAPResultCompareTrue if any value of the attribute (looked up
//     case-insensitively by getAttribute) satisfies valueMatch against the
//     asserted value, LDAPResultCompareFalse otherwise;
//   - LDAPResultOperationsError on storage errors.
func (server *Server) handleCompareInternal(state *State, r *message.CompareRequest) (int, error) {
	attr := string(r.Ava().AttributeDesc())
	expected := string(r.Ava().AssertionValue())

	dn, err := server.checkDN(string(r.Entry()), false)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	// Check permissions. Acl.Check takes (login, action, target, attributes),
	// the order used by every other call site in this file; the action must
	// come before the DN.
	if !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
		return ldap.LDAPResultInsufficientAccessRights, nil
	}

	// Do query
	exists, err := server.objectExists(dn)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}
	if !exists {
		return ldap.LDAPResultNoSuchObject, fmt.Errorf("Not found: %s", dn)
	}

	values, err := server.getAttribute(dn, attr)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	for _, v := range values {
		if valueMatch(attr, v, expected) {
			return ldap.LDAPResultCompareTrue, nil
		}
	}

	return ldap.LDAPResultCompareFalse, nil
}

// Search request -------------------------

// handleSearch serves an LDAP Search request. handleSearchInternal writes the
// individual result entries to w; this wrapper then sends the terminating
// SearchResultDone carrying the result code and, when present, the error text
// as diagnostic message. Non-success outcomes are also logged together with
// the request.
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	st := s.(*State)
	req := m.GetSearchRequest()

	resultCode, searchErr := server.handleSearchInternal(st, w, &req)

	done := ldap.NewResponse(resultCode)
	if searchErr != nil {
		done.SetDiagnosticMessage(searchErr.Error())
	}
	if resultCode != ldap.LDAPResultSuccess {
		server.logger.Printf("Failed to do search %#v (%s)", req, searchErr)
	}
	w.Write(message.SearchResultDone(done))
}

// handleSearchInternal evaluates an LDAP Search request for the user logged in
// on state. It writes one SearchResultEntry to w per returned object and gives
// back the result code for the final SearchResultDone, plus an optional error
// whose text becomes the diagnostic message.
//
// Objects are listed from the key-value store under the prefix that dnToConsul
// derives from the base DN. They are then:
//   - restricted by scope: base object only, direct children (one more
//     comma-separated DN component than the base), or otherwise everything
//     under the base;
//   - filtered by the request filter;
//   - trimmed to the attributes the client requested and is allowed to read.
//
// Security: the filter is evaluated only against attributes the user may read.
// Evaluating it against hidden attributes would let a client infer their values
// (e.g. a password hash) from whether an entry is returned.
func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) {

	baseObject, err := server.checkDN(string(r.BaseObject()), true)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	server.logger.Tracef("-- SEARCH REQUEST: --")
	server.logger.Tracef("Request BaseDn=%s", baseObject)
	server.logger.Tracef("Request Filter=%s", r.Filter())
	server.logger.Tracef("Request FilterString=%s", r.FilterString())
	server.logger.Tracef("Request Attributes=%s", r.Attributes())
	server.logger.Tracef("Request TimeLimit=%d", r.TimeLimit().Int())

	if !server.config.Acl.Check(&state.login, "read", baseObject, []string{}) {
		return ldap.LDAPResultInsufficientAccessRights, fmt.Errorf("Please specify a base object on which you have read rights")
	}

	baseObjectLevel := len(strings.Split(baseObject, ","))

	basePath, err := dnToConsul(baseObject)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}
	scope := r.Scope()
	if scope == message.SearchRequestScopeBaseObject {
		// Only the base object's own attribute keys are needed.
		basePath += "/attribute="
	} else {
		basePath += "/"
	}

	data, _, err := server.kv.List(basePath, &server.readOpts)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	entries, err := parseConsulResult(data)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	server.logger.Tracef("in %s: %#v", basePath, data)
	server.logger.Tracef("%#v", entries)

	requested := r.Attributes()

	for dn, entry := range entries {
		// Restrict to the requested scope (whole subtree needs no extra check).
		switch scope {
		case message.SearchRequestScopeBaseObject:
			if dn != baseObject {
				continue
			}
		case message.SearchRequestSingleLevel:
			if len(strings.Split(dn, ",")) != baseObjectLevel+1 {
				continue
			}
		}

		// Filter out if user is not allowed to read this
		if !server.config.Acl.Check(&state.login, "read", dn, []string{}) {
			continue
		}

		// Keep only the attributes the user may read: both the filter and the
		// returned entry must be limited to these.
		visible := make(Entry, len(entry))
		for attr, val := range entry {
			if server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
				visible[attr] = val
			}
		}

		// Filter out if we don't match requested filter
		matched, err := applyFilter(visible, r.Filter())
		if err != nil {
			return ldap.LDAPResultUnwillingToPerform, err
		}
		if !matched {
			continue
		}

		e := ldap.NewSearchResultEntry(dn)
		for attr, val := range visible {
			// If attribute is not in request, exclude it from returned entry
			if len(requested) > 0 {
				found := false
				for _, want := range requested {
					// A lone "1.1" selector asks for no attributes at all.
					if string(want) == "1.1" && len(requested) == 1 {
						break
					}
					if string(want) == "*" || strings.EqualFold(string(want), attr) {
						found = true
						break
					}
				}
				if !found {
					continue
				}
			}
			// Send result
			for _, v := range val {
				e.AddAttribute(message.AttributeDescription(attr),
					message.AttributeValue(v))
			}
		}
		w.Write(e)
	}

	return ldap.LDAPResultSuccess, nil
}

func applyFilter(entry Entry, filter message.Filter) (bool, error) {
	if fAnd, ok := filter.(message.FilterAnd); ok {
		for _, cond := range fAnd {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if !res {
				return false, nil
			}
		}
		return true, nil
	} else if fOr, ok := filter.(message.FilterOr); ok {
		for _, cond := range fOr {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if res {
				return true, nil
			}
		}
		return false, nil
	} else if fNot, ok := filter.(message.FilterNot); ok {
		res, err := applyFilter(entry, fNot.Filter)
		if err != nil {
			return false, err
		}
		return !res, nil
	} else if fPresent, ok := filter.(message.FilterPresent); ok {
		what := string(fPresent)
		// Case insensitive search
		for desc, values := range entry {
			if strings.EqualFold(what, desc) {
				return len(values) > 0, nil
			}
		}
		return false, nil
	} else if fEquality, ok := filter.(message.FilterEqualityMatch); ok {
		desc := string(fEquality.AttributeDesc())
		target := string(fEquality.AssertionValue())
		// Case insensitive attribute search
		for entry_desc, value := range entry {
			if strings.EqualFold(entry_desc, desc) {
				for _, val := range value {
					if valueMatch(entry_desc, val, target) {
						return true, nil
					}
				}
				return false, nil
			}
		}
		return false, nil
	} else {
		return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter)
	}
}