From 66c64797706a2e62424c3523564b99f0597cde03 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 27 Jan 2020 16:08:35 +0100 Subject: Implement TLS mechanisms correctly, I hope --- main.go | 80 +++++++++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 56 insertions(+), 24 deletions(-) (limited to 'main.go') diff --git a/main.go b/main.go index 3d2d7f2..88ea3dd 100644 --- a/main.go +++ b/main.go @@ -38,9 +38,10 @@ type ConfigFile struct { BindAddress string `json:"bind_address"` ConsulHost string `json:"consul_host"` Acl []string `json:"acl"` - SSLCertFile string `json:"ssl_cert_file"` - SSLKeyFile string `json:"ssl_key_file"` - SSLServerName string `json:"ssl_server_name"` + TLSCertFile string `json:"tls_cert_file"` + TLSKeyFile string `json:"tls_key_file"` + TLSServerName string `json:"tls_server_name"` + UseStartTLS bool `json:"use_starttls"` } type Config struct { @@ -50,7 +51,8 @@ type Config struct { Acl ACL - TlsConfig *tls.Config + TLSConfig *tls.Config + UseStartTLS bool } type Server struct { @@ -92,14 +94,15 @@ func readConfig() Config { BindAddress: config_file.BindAddress, ConsulHost: config_file.ConsulHost, Acl: acl, + UseStartTLS: config_file.UseStartTLS, } - if config_file.SSLCertFile != "" && config_file.SSLKeyFile != "" && config_file.SSLServerName != "" { - cert_txt, err := ioutil.ReadFile(config_file.SSLCertFile) + if config_file.TLSCertFile != "" && config_file.TLSKeyFile != "" && config_file.TLSServerName != "" { + cert_txt, err := ioutil.ReadFile(config_file.TLSCertFile) if err != nil { panic(err) } - key_txt, err := ioutil.ReadFile(config_file.SSLKeyFile) + key_txt, err := ioutil.ReadFile(config_file.TLSKeyFile) if err != nil { panic(err) } @@ -107,13 +110,14 @@ func readConfig() Config { if err != nil { panic(err) } - ret.TlsConfig = &tls.Config{ - MinVersion: tls.VersionSSL30, + ret.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS10, MaxVersion: tls.VersionTLS12, Certificates: []tls.Certificate{cert}, - ServerName: config_file.SSLServerName, + ServerName: config_file.TLSServerName, } - + } else { + log.Printf("Warning: no TLS configuration provided, running an insecure server.") } return ret @@ -160,31 +164,42 @@ func main() { } routes := ldap.NewRouteMux() + routes.Bind(gobottin.handleBind) routes.Search(gobottin.handleSearch) routes.Add(gobottin.handleAdd) routes.Compare(gobottin.handleCompare) routes.Delete(gobottin.handleDelete) routes.Modify(gobottin.handleModify) + + if config.TLSConfig != nil && config.UseStartTLS { + routes.Extended(gobottin.handleStartTLS). + RequestName(ldap.NoticeOfStartTLS).Label("StartTLS") + } + ldapserver.Handle(routes) - if config.TlsConfig != nil { + go func() { + // When CTRL+C, SIGINT and SIGTERM signal occurs + // Then stop server gracefully + ch := make(chan os.Signal) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + <-ch + close(ch) + ldapserver.Stop() + }() + + if config.TLSConfig != nil && !config.UseStartTLS { secureConn := func(s *ldap.Server) { - s.Listener = tls.NewListener(s.Listener, config.TlsConfig) + s.Listener = tls.NewListener(s.Listener, config.TLSConfig) } - go ldapserver.ListenAndServe(config.BindAddress, secureConn) + err = ldapserver.ListenAndServe(config.BindAddress, secureConn) } else { - go ldapserver.ListenAndServe(config.BindAddress) + err = ldapserver.ListenAndServe(config.BindAddress) + } + if err != nil { + panic(err) } - - // When CTRL+C, SIGINT and SIGTERM signal occurs - // Then stop server gracefully - ch := make(chan os.Signal) - signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) - <-ch - close(ch) - - ldapserver.Stop() } func (server *Server) init() error { @@ -328,6 +343,23 @@ func (server *Server) checkSuffix(dn string, allow_extend bool) (string, error) } } +func (server *Server) handleStartTLS(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { + tlsConn := tls.Server(m.Client.GetConn(), server.config.TLSConfig) + res := ldap.NewExtendedResponse(ldap.LDAPResultSuccess) + res.SetResponseName(ldap.NoticeOfStartTLS) + w.Write(res) + + if err := tlsConn.Handshake(); err != nil { + log.Printf("StartTLS Handshake error %v", err) + res.SetDiagnosticMessage(fmt.Sprintf("StartTLS Handshake error : \"%s\"", err.Error())) + res.SetResultCode(ldap.LDAPResultOperationsError) + w.Write(res) + return + } + + m.Client.SetConn(tlsConn) +} + func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { state := s.(*State) r := m.GetBindRequest() -- cgit v1.2.3