aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go119
1 files changed, 77 insertions, 42 deletions
diff --git a/main.go b/main.go
index ddd6552..a58b2a1 100644
--- a/main.go
+++ b/main.go
@@ -1,28 +1,28 @@
package main
import (
- "os"
- "strings"
- "flag"
- "log"
- "net/http"
- "io/ioutil"
- "encoding/json"
- "encoding/base64"
"crypto/rand"
"crypto/tls"
+ "encoding/base64"
+ "encoding/json"
+ "flag"
"html/template"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+ "strings"
- "github.com/gorilla/sessions"
"github.com/go-ldap/ldap/v3"
+ "github.com/gorilla/sessions"
)
type ConfigFile struct {
HttpBindAddr string `json:"http_bind_addr"`
SessionKey string `json:"session_key"`
LdapServerAddr string `json:"ldap_server_addr"`
- LdapTLS bool `json:"ldap_tls"`
- UserFormat string `json:"user_format"`
+ LdapTLS bool `json:"ldap_tls"`
+ UserFormat string `json:"user_format"`
}
var configFlag = flag.String("config", "./config.json", "Configuration file path")
@@ -30,12 +30,13 @@ var configFlag = flag.String("config", "./config.json", "Configuration file path
var config *ConfigFile
const SESSION_NAME = "guichet_session"
-var store *sessions.CookieStore = nil
-func readConfig() ConfigFile{
+var store sessions.Store = nil
+
+func readConfig() ConfigFile {
key_bytes := make([]byte, 32)
n, err := rand.Read(key_bytes)
- if err!= nil || n != 32 {
+ if err != nil || n != 32 {
log.Fatal(err)
}
@@ -43,8 +44,8 @@ func readConfig() ConfigFile{
HttpBindAddr: ":9991",
SessionKey: base64.StdEncoding.EncodeToString(key_bytes),
LdapServerAddr: "ldap://127.0.0.1:389",
- LdapTLS: false,
- UserFormat: "cn=%s,ou=users,dc=example,dc=com",
+ LdapTLS: false,
+ UserFormat: "cn=%s,ou=users,dc=example,dc=com",
}
_, err = os.Stat(*configFlag)
@@ -87,9 +88,10 @@ func main() {
config_file := readConfig()
config = &config_file
- store = sessions.NewCookieStore([]byte(config.SessionKey))
+ store = sessions.NewFilesystemStore("", []byte(config.SessionKey))
http.HandleFunc("/", handleHome)
+ http.HandleFunc("/logout", handleLogout)
staticfiles := http.FileServer(http.Dir("static"))
http.Handle("/static/", http.StripPrefix("/static/", staticfiles))
@@ -102,9 +104,9 @@ func main() {
type LoginInfo struct {
Username string
- DN string
+ DN string
Password string
- }
+}
func logRequest(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -115,17 +117,23 @@ func logRequest(handler http.Handler) http.Handler {
func checkLogin(w http.ResponseWriter, r *http.Request) *LoginInfo {
session, err := store.Get(r, SESSION_NAME)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return nil
- }
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return nil
+ }
- login_info, has_login_info := session.Values["login_info"]
- if !has_login_info {
+ username, ok := session.Values["login_username"]
+ password, ok2 := session.Values["login_password"]
+ user_dn, ok3 := session.Values["login_dn"]
+ if !(ok && ok2 && ok3) {
return handleLogin(w, r)
}
- return login_info.(*LoginInfo)
+ return &LoginInfo{
+ DN: user_dn.(string),
+ Username: username.(string),
+ Password: password.(string),
+ }
}
func ldapOpen(w http.ResponseWriter) *ldap.Conn {
@@ -149,18 +157,15 @@ func ldapOpen(w http.ResponseWriter) *ldap.Conn {
// Templates ----
type LoginFormData struct {
- Username string
+ Username string
ErrorMessage string
}
-var (
- templateLogin = template.Must(template.ParseFiles("templates/layout.html", "templates/login.html"))
- templateHome = template.Must(template.ParseFiles("templates/layout.html", "templates/home.html"))
-)
-
// Page handlers ----
func handleHome(w http.ResponseWriter, r *http.Request) {
+ templateHome := template.Must(template.ParseFiles("templates/layout.html", "templates/home.html"))
+
login := checkLogin(w, r)
if login == nil {
return
@@ -169,7 +174,29 @@ func handleHome(w http.ResponseWriter, r *http.Request) {
templateHome.Execute(w, login)
}
+func handleLogout(w http.ResponseWriter, r *http.Request) {
+ session, err := store.Get(r, SESSION_NAME)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ delete(session.Values, "login_username")
+ delete(session.Values, "login_password")
+ delete(session.Values, "login_dn")
+
+ err = session.Save(r, w)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, "/", http.StatusFound)
+}
+
func handleLogin(w http.ResponseWriter, r *http.Request) *LoginInfo {
+ templateLogin := template.Must(template.ParseFiles("templates/layout.html", "templates/login.html"))
+
if r.Method == "GET" {
templateLogin.Execute(w, LoginFormData{})
return nil
@@ -177,23 +204,18 @@ func handleLogin(w http.ResponseWriter, r *http.Request) *LoginInfo {
r.ParseForm()
username := strings.Join(r.Form["username"], "")
+ password := strings.Join(r.Form["password"], "")
user_dn := strings.ReplaceAll(config.UserFormat, "%s", username)
- login_info := &LoginInfo{
- DN: user_dn,
- Username: username,
- Password: strings.Join(r.Form["password"], ""),
- }
-
l := ldapOpen(w)
if l == nil {
return nil
}
- err := l.Bind(user_dn, login_info.Password)
+ err := l.Bind(user_dn, password)
if err != nil {
templateLogin.Execute(w, LoginFormData{
- Username: username,
+ Username: username,
ErrorMessage: err.Error(),
})
return nil
@@ -206,8 +228,21 @@ func handleLogin(w http.ResponseWriter, r *http.Request) *LoginInfo {
return nil
}
- session.Values["login_info"] = login_info
- return login_info
+ session.Values["login_username"] = username
+ session.Values["login_password"] = password
+ session.Values["login_dn"] = user_dn
+
+ err = session.Save(r, w)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return nil
+ }
+
+ return &LoginInfo{
+ DN: user_dn,
+ Username: username,
+ Password: password,
+ }
} else {
http.Error(w, "Unsupported method", http.StatusBadRequest)
return nil