aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go80
1 files changed, 56 insertions, 24 deletions
diff --git a/main.go b/main.go
index 3d2d7f2..88ea3dd 100644
--- a/main.go
+++ b/main.go
@@ -38,9 +38,10 @@ type ConfigFile struct {
BindAddress string `json:"bind_address"`
ConsulHost string `json:"consul_host"`
Acl []string `json:"acl"`
- SSLCertFile string `json:"ssl_cert_file"`
- SSLKeyFile string `json:"ssl_key_file"`
- SSLServerName string `json:"ssl_server_name"`
+ TLSCertFile string `json:"tls_cert_file"`
+ TLSKeyFile string `json:"tls_key_file"`
+ TLSServerName string `json:"tls_server_name"`
+ UseStartTLS bool `json:"use_starttls"`
}
type Config struct {
@@ -50,7 +51,8 @@ type Config struct {
Acl ACL
- TlsConfig *tls.Config
+ TLSConfig *tls.Config
+ UseStartTLS bool
}
type Server struct {
@@ -92,14 +94,15 @@ func readConfig() Config {
BindAddress: config_file.BindAddress,
ConsulHost: config_file.ConsulHost,
Acl: acl,
+ UseStartTLS: config_file.UseStartTLS,
}
- if config_file.SSLCertFile != "" && config_file.SSLKeyFile != "" && config_file.SSLServerName != "" {
- cert_txt, err := ioutil.ReadFile(config_file.SSLCertFile)
+ if config_file.TLSCertFile != "" && config_file.TLSKeyFile != "" && config_file.TLSServerName != "" {
+ cert_txt, err := ioutil.ReadFile(config_file.TLSCertFile)
if err != nil {
panic(err)
}
- key_txt, err := ioutil.ReadFile(config_file.SSLKeyFile)
+ key_txt, err := ioutil.ReadFile(config_file.TLSKeyFile)
if err != nil {
panic(err)
}
@@ -107,13 +110,14 @@ func readConfig() Config {
if err != nil {
panic(err)
}
- ret.TlsConfig = &tls.Config{
- MinVersion: tls.VersionSSL30,
+ ret.TLSConfig = &tls.Config{
+ MinVersion: tls.VersionTLS10,
MaxVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
- ServerName: config_file.SSLServerName,
+ ServerName: config_file.TLSServerName,
}
-
+ } else {
+ log.Printf("Warning: no TLS configuration provided, running an insecure server.")
}
return ret
@@ -160,31 +164,42 @@ func main() {
}
routes := ldap.NewRouteMux()
+
routes.Bind(gobottin.handleBind)
routes.Search(gobottin.handleSearch)
routes.Add(gobottin.handleAdd)
routes.Compare(gobottin.handleCompare)
routes.Delete(gobottin.handleDelete)
routes.Modify(gobottin.handleModify)
+
+ if config.TLSConfig != nil && config.UseStartTLS {
+ routes.Extended(gobottin.handleStartTLS).
+ RequestName(ldap.NoticeOfStartTLS).Label("StartTLS")
+ }
+
ldapserver.Handle(routes)
- if config.TlsConfig != nil {
+ go func() {
+ // When CTRL+C, SIGINT and SIGTERM signal occurs
+ // Then stop server gracefully
+ ch := make(chan os.Signal)
+ signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
+ <-ch
+ close(ch)
+ ldapserver.Stop()
+ }()
+
+ if config.TLSConfig != nil && !config.UseStartTLS {
secureConn := func(s *ldap.Server) {
- s.Listener = tls.NewListener(s.Listener, config.TlsConfig)
+ s.Listener = tls.NewListener(s.Listener, config.TLSConfig)
}
- go ldapserver.ListenAndServe(config.BindAddress, secureConn)
+ err = ldapserver.ListenAndServe(config.BindAddress, secureConn)
} else {
- go ldapserver.ListenAndServe(config.BindAddress)
+ err = ldapserver.ListenAndServe(config.BindAddress)
+ }
+ if err != nil {
+ panic(err)
}
-
- // When CTRL+C, SIGINT and SIGTERM signal occurs
- // Then stop server gracefully
- ch := make(chan os.Signal)
- signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
- <-ch
- close(ch)
-
- ldapserver.Stop()
}
func (server *Server) init() error {
@@ -328,6 +343,23 @@ func (server *Server) checkSuffix(dn string, allow_extend bool) (string, error)
}
}
+func (server *Server) handleStartTLS(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
+ tlsConn := tls.Server(m.Client.GetConn(), server.config.TLSConfig)
+ res := ldap.NewExtendedResponse(ldap.LDAPResultSuccess)
+ res.SetResponseName(ldap.NoticeOfStartTLS)
+ w.Write(res)
+
+ if err := tlsConn.Handshake(); err != nil {
+ log.Printf("StartTLS Handshake error %v", err)
+ res.SetDiagnosticMessage(fmt.Sprintf("StartTLS Handshake error : \"%s\"", err.Error()))
+ res.SetResultCode(ldap.LDAPResultOperationsError)
+ w.Write(res)
+ return
+ }
+
+ m.Client.SetConn(tlsConn)
+}
+
func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
state := s.(*State)
r := m.GetBindRequest()