package main import ( "encoding/base64" "encoding/json" "fmt" "log" "math/rand" "os" "os/signal" "strings" "syscall" ldap "./ldapserver" consul "github.com/hashicorp/consul/api" message "github.com/vjeantet/goldap/message" ) func dnToConsul(dn string) string { rdns := strings.Split(dn, ",") // Reverse rdns for i, j := 0, len(rdns)-1; i < j; i, j = i+1, j-1 { rdns[i], rdns[j] = rdns[j], rdns[i] } return strings.Join(rdns, "/") } func consulToDN(pair *consul.KVPair) (string, string, []byte) { path := strings.Split(pair.Key, "/") dn := "" for _, cpath := range path { if cpath == "" { continue } kv := strings.Split(cpath, "=") if len(kv) == 2 && kv[0] == "attribute" { return dn, kv[1], pair.Value } if dn != "" { dn = "," + dn } dn = cpath + dn } return dn, "", nil } func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) { aggregator := map[string]Entry{} for _, kv := range data { log.Printf("%s %s", kv.Key, string(kv.Value)) dn, attr, val := consulToDN(kv) if attr == "" || val == nil { continue } if _, exists := aggregator[dn]; !exists { aggregator[dn] = Entry{} } var value interface{} err := json.Unmarshal(val, &value) if err != nil { return nil, err } if vlist, ok := value.([]interface{}); ok { vlist2 := []string{} for _, v := range vlist { if vstr, ok := v.(string); ok { vlist2 = append(vlist2, vstr) } else { return nil, fmt.Errorf("Not a string: %#v", v) } } aggregator[dn][attr] = vlist2 } else if vstr, ok := value.(string); ok { aggregator[dn][attr] = vstr } else { return nil, fmt.Errorf("Not a string or a list of strings: %#v", value) } } return aggregator, nil } type DNComponent struct { Type string Value string } func parseDN(dn string) ([]DNComponent, error) { rdns := strings.Split(dn, ",") ret := []DNComponent{} for _, rdn := range rdns { splits := strings.Split(rdn, "=") if len(splits) != 2 { return nil, fmt.Errorf("Wrong DN component: %s (expected type=value)", rdn) } ret = append(ret, DNComponent{ Type: splits[0], Value: splits[1], }) } return ret, nil } type Config struct { Suffix string } type Server struct { config Config kv *consul.KV } type State struct { bindDn string } type Entry map[string]interface{} func main() { //ldap logger ldap.Logger = log.New(os.Stdout, "[server] ", log.LstdFlags) // Connect to Consul client, err := consul.NewClient(consul.DefaultConfig()) if err != nil { panic(err) } kv := client.KV() // TODO read config from somewhere config := Config{ Suffix: "dc=gobottin,dc=eu", } gobottin := Server{config: config, kv: kv} err = gobottin.init() if err != nil { panic(err) } //Create a new LDAP Server ldapserver := ldap.NewServer() ldapserver.NewUserState = func() ldap.UserState { return &State{} } routes := ldap.NewRouteMux() routes.Bind(gobottin.handleBind) routes.Search(gobottin.handleSearch) ldapserver.Handle(routes) // listen on 10389 go ldapserver.ListenAndServe("127.0.0.1:10389") // When CTRL+C, SIGINT and SIGTERM signal occurs // Then stop server gracefully ch := make(chan os.Signal) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) <-ch close(ch) ldapserver.Stop() } func (server *Server) init() error { pair, _, err := server.kv.Get(dnToConsul(server.config.Suffix)+"/attribute=objectClass", nil) if err != nil { return err } if pair != nil { return nil } base_attributes := Entry{ "objectClass": []string{"top", "dcObject", "organization"}, "structuralObjectClass": "Organization", } suffix_dn, err := parseDN(server.config.Suffix) if err != nil { return err } base_attributes[suffix_dn[0].Type] = suffix_dn[0].Value err = server.addElements(server.config.Suffix, base_attributes) if err != nil { return err } admin_pass := make([]byte, 8) rand.Read(admin_pass) admin_pass_str := base64.RawURLEncoding.EncodeToString(admin_pass) admin_pass_hash := SSHAEncode([]byte(admin_pass_str)) admin_dn := "cn=admin," + server.config.Suffix admin_attributes := Entry{ "objectClass": []string{"simpleSecurityObject", "organizationalRole"}, "description": "LDAP administrator", "cn": "admin", "userpassword": admin_pass_hash, "structuralObjectClass": "organizationalRole", "permissions": []string{"read", "write"}, } err = server.addElements(admin_dn, admin_attributes) if err != nil { return err } log.Printf( "It seems to be a new installation, we created a default user for you:\n\n dn: %s\n password: %s\n\nWe didn't use true random, you should replace it as soon as possible.", admin_dn, admin_pass_str, ) return nil } func (server *Server) addElements(dn string, attrs Entry) error { prefix := dnToConsul(dn) for k, v := range attrs { json, err := json.Marshal(v) if err != nil { return err } pair := &consul.KVPair{Key: prefix + "/attribute=" + k, Value: json} _, err = server.kv.Put(pair, nil) if err != nil { return err } } return nil } func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { state := s.(*State) r := m.GetBindRequest() result_code, err := server.handleBindInternal(state, w, &r) res := ldap.NewBindResponse(result_code) if err != nil { res.SetDiagnosticMessage(err.Error()) log.Printf("Failed bind for %s: %s", string(r.Name()), err.Error()) } w.Write(res) } func (server *Server) handleBindInternal(state *State, w ldap.ResponseWriter, r *message.BindRequest) (int, error) { pair, _, err := server.kv.Get(dnToConsul(string(r.Name()))+"/attribute=userpassword", nil) if err != nil { return ldap.LDAPResultOperationsError, err } if pair == nil { return ldap.LDAPResultNoSuchObject, nil } hash := "" err = json.Unmarshal(pair.Value, &hash) if err != nil { return ldap.LDAPResultOperationsError, err } valid := SSHAMatches(hash, []byte(r.AuthenticationSimple())) if valid { state.bindDn = string(r.Name()) return ldap.LDAPResultSuccess, nil } else { return ldap.LDAPResultInvalidCredentials, nil } } func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { state := s.(*State) r := m.GetSearchRequest() code, err := server.handleSearchInternal(state, w, &r) res := ldap.NewResponse(code) if err != nil { res.SetDiagnosticMessage(err.Error()) } w.Write(message.SearchResultDone(res)) } func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) { log.Printf("-- SEARCH REQUEST: --") log.Printf("Request BaseDn=%s", r.BaseObject()) log.Printf("Request Filter=%s", r.Filter()) log.Printf("Request FilterString=%s", r.FilterString()) log.Printf("Request Attributes=%s", r.Attributes()) log.Printf("Request TimeLimit=%d", r.TimeLimit().Int()) // TODO check authorizations basePath := dnToConsul(string(r.BaseObject())) + "/" data, _, err := server.kv.List(basePath, nil) if err != nil { return ldap.LDAPResultOperationsError, err } entries, err := parseConsulResult(data) if err != nil { return ldap.LDAPResultOperationsError, err } log.Printf("in %s: %#v", basePath, data) log.Printf("%#v", entries) for dn, entry := range entries { // TODO filter out if no permission to read this matched, err := applyFilter(entry, r.Filter()) if err != nil { return ldap.LDAPResultOperationsError, err } if !matched { continue } e := ldap.NewSearchResultEntry(dn) for attr, val := range entry { // If attribute is not in request, exclude it from returned entry if len(r.Attributes()) > 0 { found := false for _, need := range r.Attributes() { if string(need) == attr { found = true break } } if !found { continue } } // Send result if val_str, ok := val.(string); ok { e.AddAttribute(message.AttributeDescription(attr), message.AttributeValue(val_str)) } else if val_strlist, ok := val.([]string); ok { for _, v := range val_strlist { e.AddAttribute(message.AttributeDescription(attr), message.AttributeValue(v)) } } else { panic(fmt.Sprintf("Invalid value: %#v", val)) } } w.Write(e) } return ldap.LDAPResultSuccess, nil } func applyFilter(entry Entry, filter message.Filter) (bool, error) { if fAnd, ok := filter.(message.FilterAnd); ok { for _, cond := range fAnd { res, err := applyFilter(entry, cond) if err != nil { return false, err } if !res { return false, nil } } return true, nil } else if fOr, ok := filter.(message.FilterOr); ok { for _, cond := range fOr { res, err := applyFilter(entry, cond) if err != nil { return false, err } if res { return true, nil } } return false, nil } else if fNot, ok := filter.(message.FilterNot); ok { res, err := applyFilter(entry, fNot.Filter) if err != nil { return false, err } return !res, nil } else if fPresent, ok := filter.(message.FilterPresent); ok { what := string(fPresent) log.Printf("Present filter: %s", what) if _, ok := entry[what]; ok { return true, nil } return false, nil } else if fEquality, ok := filter.(message.FilterEqualityMatch); ok { desc := string(fEquality.AttributeDesc()) target := string(fEquality.AssertionValue()) if value, ok := entry[desc]; ok { if vstr, ok := value.(string); ok { // If we have one value for the key, match exactly return vstr == target, nil } else if vlist, ok := value.([]string); ok { // If we have several values for the key, one must match for _, val := range vlist { if val == target { return true, nil } } return false, nil } else { panic(fmt.Sprintf("Invalid value: %#v", value)) } } else { return false, nil } } else { return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter) } }