package main

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"os/signal"
	"strings"
	"syscall"

	ldap "./ldapserver"
	consul "github.com/hashicorp/consul/api"
	message "github.com/vjeantet/goldap/message"
)

// dnToConsul converts an LDAP DN into a Consul key prefix. It reverses the
// comma-separated RDNs and joins them with "/", so
// "cn=admin,dc=gobottin,dc=eu" becomes "dc=eu/dc=gobottin/cn=admin".
// The DN is split naively on ",": escaped commas and whitespace after
// commas are not handled.
func dnToConsul(dn string) string {
	rdns := strings.Split(dn, ",")
	for i, j := 0, len(rdns)-1; i < j; i, j = i+1, j-1 {
		rdns[i], rdns[j] = rdns[j], rdns[i]
	}
	return strings.Join(rdns, "/")
}

// consulToDN is the inverse of dnToConsul for a single key/value pair. It
// walks the key's "/"-separated segments and rebuilds the DN in LDAP order
// until it reaches an "attribute=<name>" segment. It then returns the DN,
// the attribute name and the raw stored value. If the key has no attribute
// segment, it returns the accumulated DN, "" and nil.
func consulToDN(pair *consul.KVPair) (string, string, []byte) {
	dn := ""
	for _, cpath := range strings.Split(pair.Key, "/") {
		if cpath == "" {
			continue
		}
		kv := strings.Split(cpath, "=")
		if len(kv) == 2 && kv[0] == "attribute" {
			return dn, kv[1], pair.Value
		}
		// Key segments are stored root-first, so each new RDN is prepended.
		if dn != "" {
			dn = "," + dn
		}
		dn = cpath + dn
	}
	return dn, "", nil
}

// parseValue decodes a stored attribute value. addElements writes values as
// a JSON list of strings. A bare JSON string is also accepted and returned as
// a one-element list. Any other JSON yields an error.
func parseValue(value []byte) ([]string, error) {
	var list []string
	if err := json.Unmarshal(value, &list); err == nil {
		return list, nil
	}
	var single string
	if err := json.Unmarshal(value, &single); err == nil {
		return []string{single}, nil
	}
	return nil, fmt.Errorf("Not a string or list of strings: %s", value)
}

// parseConsulResult groups a flat list of Consul KV pairs into LDAP entries
// keyed by DN. Pairs whose key has no "attribute=" segment, or whose value is
// nil, are skipped. The first value that fails to decode aborts the whole
// conversion.
func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) {
	aggregator := map[string]Entry{}
	for _, kv := range data {
		// Only the key is logged: values include userpassword hashes.
		log.Printf("(parseConsulResult) %s", kv.Key)
		dn, attr, raw := consulToDN(kv)
		if attr == "" || raw == nil {
			continue
		}
		value, err := parseValue(raw)
		if err != nil {
			return nil, err
		}
		entry, ok := aggregator[dn]
		if !ok {
			entry = Entry{}
			aggregator[dn] = entry
		}
		entry[attr] = value
	}
	return aggregator, nil
}

// DNComponent is a single "type=value" RDN of a distinguished name.
type DNComponent struct {
	Type  string
	Value string
}

// parseDN splits a DN into its RDN components, in the order they appear.
// Each component must contain at least one "="; only the first "=" separates
// type from value, so values may themselves contain "=". Escaped commas and
// multi-valued RDNs are not supported.
func parseDN(dn string) ([]DNComponent, error) {
	rdns := strings.Split(dn, ",")
	ret := make([]DNComponent, 0, len(rdns))
	for _, rdn := range rdns {
		splits := strings.SplitN(rdn, "=", 2)
		if len(splits) != 2 {
			return nil, fmt.Errorf("Wrong DN component: %s (expected type=value)", rdn)
		}
		ret = append(ret, DNComponent{
			Type:  splits[0],
			Value: splits[1],
		})
	}
	return ret, nil
}

// Config holds the server settings.
type Config struct {
	// Suffix is the base DN this server is authoritative for.
	Suffix string
}

// Server serves LDAP requests from data stored in Consul's KV store.
type Server struct {
	config Config
	kv     *consul.KV
}

// State is the per-client state created through ldapserver.NewUserState.
type State struct {
	// bindDn is set by handleBindInternal after a successful simple bind.
	bindDn string
}

// Entry maps attribute names to their (possibly multiple) values.
type Entry map[string][]string

func main() {
	// ldap logger
	ldap.Logger = log.New(os.Stdout, "[server] ", log.LstdFlags)

	// Connect to Consul
	client, err := consul.NewClient(consul.DefaultConfig())
	if err != nil {
		panic(err)
	}
	kv := client.KV()

	// TODO read config from somewhere
	config := Config{
		Suffix: "dc=gobottin,dc=eu",
	}

	gobottin := Server{config: config, kv: kv}
	if err := gobottin.init(); err != nil {
		panic(err)
	}

	// Create a new LDAP server
	srv := ldap.NewServer()
	srv.NewUserState = func() ldap.UserState {
		return &State{}
	}

	routes := ldap.NewRouteMux()
	routes.Bind(gobottin.handleBind)
	routes.Search(gobottin.handleSearch)
	routes.Add(gobottin.handleAdd)
	srv.Handle(routes)

	// Listen on 10389. Any error returned by ListenAndServe is discarded.
	go srv.ListenAndServe("127.0.0.1:10389")

	// On SIGINT or SIGTERM, stop the server gracefully. The channel is
	// buffered because signal.Notify never blocks when delivering; an
	// unbuffered channel can miss the signal.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
	<-ch
	// Unregister instead of closing: closing a channel that signal.Notify
	// may still send to would panic on a second signal.
	signal.Stop(ch)

	srv.Stop()
}

// init bootstraps a fresh directory. It is a method on *Server, not the
// package-level init hook, and runs only when main calls it. If the suffix
// entry already has an objectClass attribute in Consul, it does nothing.
// Otherwise it creates the suffix entry and a "cn=admin" entry beneath it with
// a random password, then logs those credentials once.
func (server *Server) init() error {
	pair, _, err := server.kv.Get(dnToConsul(server.config.Suffix)+"/attribute=objectClass", nil)
	if err != nil {
		return err
	}
	if pair != nil {
		return nil
	}

	suffixDN, err := parseDN(server.config.Suffix)
	if err != nil {
		return err
	}

	// Generate the admin password before writing anything. If this failed
	// after the suffix was stored, the objectClass check above would skip
	// bootstrapping on the next start and no admin would ever be created.
	adminPass := make([]byte, 8)
	if _, err := rand.Read(adminPass); err != nil {
		return err
	}
	adminPassStr := base64.RawURLEncoding.EncodeToString(adminPass)
	adminPassHash := SSHAEncode([]byte(adminPassStr))

	baseAttributes := Entry{
		"objectClass":           []string{"top", "dcObject", "organization"},
		"structuralObjectClass": []string{"Organization"},
	}
	// The naming attribute of the suffix, e.g. dc=gobottin.
	baseAttributes[suffixDN[0].Type] = []string{suffixDN[0].Value}
	if err := server.addElements(server.config.Suffix, baseAttributes); err != nil {
		return err
	}

	adminDN := "cn=admin," + server.config.Suffix
	adminAttributes := Entry{
		"objectClass":           []string{"simpleSecurityObject", "organizationalRole"},
		"description":           []string{"LDAP administrator"},
		"cn":                    []string{"admin"},
		"userpassword":          []string{adminPassHash},
		"structuralObjectClass": []string{"organizationalRole"},
		"permissions":           []string{"read", "write"},
	}
	if err := server.addElements(adminDN, adminAttributes); err != nil {
		return err
	}

	log.Printf(
		"It seems to be a new installation, we created a default user for you:\n\n dn: %s\n password: %s\n\nYou should replace it as soon as possible.",
		adminDN,
		adminPassStr,
	)

	return nil
}

// addElements stores each attribute of attrs under the Consul key
// "<dnToConsul(dn)>/attribute=<name>" as a JSON-encoded list. Existing values
// for those attributes are overwritten. Writes are issued one by one, so a
// failure can leave some attributes written and others not.
func (server *Server) addElements(dn string, attrs Entry) error {
	prefix := dnToConsul(dn)
	for name, values := range attrs {
		encoded, err := json.Marshal(values)
		if err != nil {
			return err
		}
		pair := &consul.KVPair{Key: prefix + "/attribute=" + name, Value: encoded}
		if _, err := server.kv.Put(pair, nil); err != nil {
			return err
		}
	}
	return nil
}

// getAttribute returns the decoded values of one attribute of the entry at
// dn. It returns (nil, nil) when the attribute is not stored. The attribute
// name is used verbatim in the key, so the lookup is case-sensitive.
func (server *Server) getAttribute(dn string, attr string) ([]string, error) {
	pair, _, err := server.kv.Get(dnToConsul(dn)+"/attribute="+attr, nil)
	if err != nil {
		return nil, err
	}
	if pair == nil {
		return nil, nil
	}
	return parseValue(pair.Value)
}

// objectExists reports whether any Consul key exists under dn's key prefix.
// That covers the entry's own attributes as well as any descendant entries.
func (server *Server) objectExists(dn string) (bool, error) {
	prefix := dnToConsul(dn) + "/"
	data, _, err := server.kv.List(prefix, nil)
	if err != nil {
		return false, err
	}
	return len(data) > 0, nil
}

// checkSuffix verifies that dn lies within the configured suffix.
//
// If dn equals the suffix, or ends with "," plus the suffix (whole RDNs
// only), dn is returned. If allowExtend is set and dn is a proper ancestor of
// the suffix (the suffix ends with "," plus dn), the suffix itself is
// returned, so a search rooted higher up is narrowed to the data this server
// holds. Anything else is rejected with an error; the suffix is returned
// alongside it.
func (server *Server) checkSuffix(dn string, allowExtend bool) (string, error) {
	suffix := server.config.Suffix
	if dn == suffix || strings.HasSuffix(dn, ","+suffix) {
		return dn, nil
	}
	if allowExtend && strings.HasSuffix(suffix, ","+dn) {
		return suffix, nil
	}
	return suffix, fmt.Errorf(
		"Only handling stuff under DN %s", suffix)
}

// handleBind answers an LDAP bind request. It delegates the check to
// handleBindInternal, and any error becomes the response's diagnostic message.
func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetBindRequest()

	resultCode, err := server.handleBindInternal(state, &r)

	res := ldap.NewBindResponse(resultCode)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
		log.Printf("Failed bind for %s: %s", string(r.Name()), err.Error())
	}
	w.Write(res)
}

// handleBindInternal checks a simple bind against the SSHA hashes stored in
// the entry's "userpassword" attribute. It records the bound DN in state on
// success.
func (server *Server) handleBindInternal(state *State, r *message.BindRequest) (int, error) {
	passwd, err := server.getAttribute(string(r.Name()), "userpassword")
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	if passwd == nil {
		return ldap.LDAPResultNoSuchObject, nil
	}

	for _, hash := range passwd {
		if SSHAMatches(hash, []byte(r.AuthenticationSimple())) {
			state.bindDn = string(r.Name())
			return ldap.LDAPResultSuccess, nil
		}
	}
	return ldap.LDAPResultInvalidCredentials, nil
}

// handleSearch answers an LDAP search request. handleSearchInternal streams
// the entries; this handler always closes with a SearchResultDone.
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetSearchRequest()

	code, err := server.handleSearchInternal(state, w, &r)

	res := ldap.NewResponse(code)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
	}
	w.Write(message.SearchResultDone(res))
}

// handleSearchInternal lists every entry stored under the (suffix-checked)
// base DN and writes those that match the filter.
//
// If the request names attributes, only those are returned, compared
// case-insensitively. An empty list or "*" returns every attribute. The
// search scope is not consulted: the whole subtree under the base is always
// scanned.
func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) {
	log.Printf("-- SEARCH REQUEST: --")
	log.Printf("Request BaseDn=%s", r.BaseObject())
	log.Printf("Request Filter=%s", r.Filter())
	log.Printf("Request FilterString=%s", r.FilterString())
	log.Printf("Request Attributes=%s", r.Attributes())
	log.Printf("Request TimeLimit=%d", r.TimeLimit().Int())

	// TODO check authorizations
	// TODO honour the request's scope (base / one level / subtree)

	baseObject, err := server.checkSuffix(string(r.BaseObject()), true)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	basePath := dnToConsul(baseObject) + "/"
	data, _, err := server.kv.List(basePath, nil)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	entries, err := parseConsulResult(data)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	// Counts only: a full dump would leak userpassword hashes into the logs.
	log.Printf("in %s: %d keys, %d entries", basePath, len(data), len(entries))

	// Build the requested-attribute set once for the whole request.
	allAttributes := len(r.Attributes()) == 0
	requested := make(map[string]bool, len(r.Attributes()))
	for _, a := range r.Attributes() {
		name := strings.ToLower(string(a))
		if name == "*" {
			allAttributes = true
		}
		requested[name] = true
	}

	for dn, entry := range entries {
		// TODO filter out if no permission to read this
		matched, err := applyFilter(entry, r.Filter())
		if err != nil {
			return ldap.LDAPResultUnwillingToPerform, err
		}
		if !matched {
			continue
		}

		e := ldap.NewSearchResultEntry(dn)
		for attr, values := range entry {
			if !allAttributes && !requested[strings.ToLower(attr)] {
				continue
			}
			for _, v := range values {
				e.AddAttribute(message.AttributeDescription(attr), message.AttributeValue(v))
			}
		}
		w.Write(e)
	}

	return ldap.LDAPResultSuccess, nil
}

// applyFilter evaluates an LDAP filter against one entry.
//
// Supported filters are and, or, not, present and equality. Present and
// equality match the attribute name case-insensitively; equality compares
// values exactly. Any other filter type yields an error.
func applyFilter(entry Entry, filter message.Filter) (bool, error) {
	switch f := filter.(type) {
	case message.FilterAnd:
		for _, cond := range f {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if !res {
				return false, nil
			}
		}
		return true, nil

	case message.FilterOr:
		for _, cond := range f {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if res {
				return true, nil
			}
		}
		return false, nil

	case message.FilterNot:
		res, err := applyFilter(entry, f.Filter)
		if err != nil {
			return false, err
		}
		return !res, nil

	case message.FilterPresent:
		what := string(f)
		// Case-insensitive attribute search
		for desc := range entry {
			if strings.EqualFold(what, desc) {
				return true, nil
			}
		}
		return false, nil

	case message.FilterEqualityMatch:
		desc := string(f.AttributeDesc())
		target := string(f.AssertionValue())
		// Case-insensitive attribute search; only the first matching
		// attribute name is examined.
		for entryDesc, values := range entry {
			if strings.EqualFold(entryDesc, desc) {
				for _, val := range values {
					if val == target {
						return true, nil
					}
				}
				return false, nil
			}
		}
		return false, nil

	default:
		return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter)
	}
}

// handleAdd answers an LDAP add request. It delegates to handleAddInternal,
// and any error becomes the response's diagnostic message.
func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetAddRequest()

	code, err := server.handleAddInternal(state, &r)

	res := ldap.NewResponse(code)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
	}
	w.Write(message.AddResponse(res))
}

// handleAddInternal creates a new entry under the configured suffix.
//
// It rejects:
//   - DNs outside the suffix;
//   - DNs that already have keys in Consul;
//   - requests that set memberOf directly;
//   - "member" values that are outside the suffix or do not exist.
//
// After storing the entry, it appends the new DN to each member's memberOf
// attribute. That read-modify-write is not atomic.
func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) {
	dn := string(r.Entry())

	if _, err := server.checkSuffix(dn, false); err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	exists, err := server.objectExists(dn)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}
	if exists {
		return ldap.LDAPResultEntryAlreadyExists, nil
	}

	// TODO check permissions

	var members []string
	entry := Entry{}
	for _, attribute := range r.Attributes() {
		key := string(attribute.Type_())

		if strings.EqualFold(key, "memberOf") {
			return ldap.LDAPResultObjectClassViolation, fmt.Errorf(
				"memberOf cannot be defined directly, membership must be specified in the group itself")
		}

		values := make([]string, 0, len(attribute.Vals()))
		for _, val := range attribute.Vals() {
			values = append(values, string(val))
		}

		if strings.EqualFold(key, "member") {
			members = values
			for _, member := range members {
				if _, err := server.checkSuffix(member, false); err != nil {
					return ldap.LDAPResultInvalidDNSyntax, err
				}
				memberExists, err := server.objectExists(member)
				if err != nil {
					return ldap.LDAPResultOperationsError, err
				}
				if !memberExists {
					return ldap.LDAPResultNoSuchObject, fmt.Errorf(
						"Cannot add %s to members, it does not exist!", member)
				}
			}
		}

		entry[key] = values
	}

	if err := server.addElements(dn, entry); err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	// Maintain the reverse link on each member.
	for _, member := range members {
		memberGroups, err := server.getAttribute(member, "memberOf")
		if err != nil {
			return ldap.LDAPResultOperationsError, err
		}
		if memberGroups == nil {
			memberGroups = []string{}
		}
		memberGroups = append(memberGroups, dn)
		if err := server.addElements(member, Entry{
			"memberOf": memberGroups,
		}); err != nil {
			return ldap.LDAPResultOperationsError, err
		}
	}

	return ldap.LDAPResultSuccess, nil
}