#!/usr/bin/env python3

# DEPENDENCY: python-consul
import consul

# DEPENDENCY: python-ldap
import ldap

# DEPENDENCY: passlib
from passlib.hash import ldap_salted_sha1

import os
import sys
import glob
import subprocess
import getpass
import base64
from secrets import token_bytes


"""
TODO: this will be a utility to handle secrets in the Consul database
for the various components of the Deuxfleurs infrastructure

Functionalities:
- check that secrets are correctly configured
- help user fill in secrets
- create LDAP service users and fill in corresponding secrets
- maybe one day: manage SSL certificates and keys

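It uses the files placed (recursively) in <module_name>/secrets/ to know
which secrets it should handle. The first line of each secret file holds a
directive describing what to do about that secret; the secret is stored in
Consul under the file's path with its first path component removed.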

Example directives:

USER <description>
(a secret that must be filled in by the user)

USER_LONG <description>
(the same, but the secret value spans several lines)

CMD <command>
(a secret that is generated by running this command)

CMD_ONCE <command>
(same, but value is not changed when doing a regen)

CONST <constant value>
(the secret has a constant value set here)

CONST_LONG
<constant value, several lines>
(same)

SERVICE_DN <service name> <service description>
(the LDAP DN of a service user)

SERVICE_PASSWORD <service name>
(the LDAP password for the corresponding service user)

SSL_CERT <cert name> <list of domains>
(an SSL certificate for the given domains)

SSL_KEY <cert name>
(the SSL key going with corresponding certificate)

RSA_PUBLIC_KEY <key name> <key description>
(a public RSA key)

RSA_PRIVATE_KEY <key name>
(the corresponding private RSA key)
"""


# Parameters
# URL of the LDAP server on which service accounts are checked and created.
LDAP_URL = "ldap://localhost:1389"
# Base DN of LDAP service accounts: service <name> is expected at
# "cn=<name>,<SERVICE_DN_SUFFIX>".
SERVICE_DN_SUFFIX = "ou=services,ou=users,dc=deuxfleurs,dc=fr"
# Shared Consul client, created at import time with python-consul's default
# connection settings (presumably the local agent — TODO confirm).
consul_server = consul.Consul()


# ---- Secret directive types ----
# Keywords accepted as the first word of a secret file; each value is the
# keyword itself, compared against the parsed directive.

USER             = "USER"
USER_LONG        = "USER_LONG"
CMD              = "CMD"
CMD_ONCE         = "CMD_ONCE"
CONST            = "CONST"
CONST_LONG       = "CONST_LONG"
SERVICE_DN       = "SERVICE_DN"
SERVICE_PASSWORD = "SERVICE_PASSWORD"
SSL_CERT         = "SSL_CERT"
SSL_KEY          = "SSL_KEY"
RSA_PUBLIC_KEY   = "RSA_PUBLIC_KEY"
RSA_PRIVATE_KEY  = "RSA_PRIVATE_KEY"

class bcolors:
    """ANSI escape sequences used to colour terminal output.

    ENDC resets all attributes; every other member switches on a text
    attribute or a foreground colour until the next ENDC.
    """
    # Reset and text attributes.
    ENDC = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'
    # Foreground colours (SGR codes 91-96).
    FAIL = '\x1b[91m'
    OKGREEN = '\x1b[92m'
    WARNING = '\x1b[93m'
    OKBLUE = '\x1b[94m'
    HEADER = '\x1b[95m'
    OKCYAN = '\x1b[96m'

def read_secret(key, file_path):
    """Parse one secret directive file into a secret description.

    Every line of the file is stripped of surrounding whitespace (note: this
    also removes indentation from CONST_LONG values). The first line is the
    directive: a secret type keyword followed by arguments separated by
    single spaces. For CONST_LONG, all following lines form the value.

    Args:
        key: Consul key the secret is stored under (copied into the result).
        file_path: path of the directive file to read.

    Returns:
        A dict with "type" and "key" plus the type-specific fields:
        "desc" (USER, USER_LONG), "cmd" (CMD, CMD_ONCE), "value" (CONST,
        CONST_LONG), "service" and, for SERVICE_DN, "service_desc"
        (SERVICE_DN, SERVICE_PASSWORD), "cert_name" and, for SSL_CERT,
        "cert_domains" (SSL_CERT, SSL_KEY), "key_name" and, for
        RSA_PUBLIC_KEY, "key_desc" (RSA_PUBLIC_KEY, RSA_PRIVATE_KEY).

    Calls sys.exit(-1) after printing an error if the file is empty, the
    secret type is unknown, or a directive that needs a name lacks one.
    """
    # Context manager: close the file handle right away instead of leaving
    # it to the garbage collector.
    with open(file_path, "r", encoding="utf-8") as f:
        lines = [l.strip() for l in f]
    if len(lines) == 0:
        print(bcolors.FAIL, "ERROR:", bcolors.ENDC, "Empty file in", file_path)
        sys.exit(-1)
    # Split on single spaces so re-joined descriptions/commands keep their
    # original spacing.
    l0 = lines[0].split(" ")
    stype = l0[0]
    secret = {"type": stype, "key": key}

    # These directives take a mandatory name as first argument. Without this
    # check a missing name crashes with a bare IndexError, and a doubled
    # space silently yields an empty name.
    named_types = [SERVICE_DN, SERVICE_PASSWORD, SSL_CERT, SSL_KEY,
                   RSA_PUBLIC_KEY, RSA_PRIVATE_KEY]
    if stype in named_types and (len(l0) < 2 or l0[1] == ""):
        print(bcolors.FAIL, "ERROR:", bcolors.ENDC, "Missing name argument for", stype, "in", file_path)
        sys.exit(-1)

    if stype in [USER, USER_LONG]:
        secret["desc"] = " ".join(l0[1:])
    elif stype in [CMD, CMD_ONCE]:
        secret["cmd"] = " ".join(l0[1:])
    elif stype == CONST:
        secret["value"] = " ".join(l0[1:])
    elif stype == CONST_LONG:
        secret["value"] = "\n".join(lines[1:])
    elif stype in [SERVICE_DN, SERVICE_PASSWORD]:
        secret["service"] = l0[1]
        if stype == SERVICE_DN:
            secret["service_desc"] = " ".join(l0[2:])
    elif stype in [SSL_CERT, SSL_KEY]:
        secret["cert_name"] = l0[1]
        if stype == SSL_CERT:
            secret["cert_domains"] = l0[2:]
    elif stype in [RSA_PUBLIC_KEY, RSA_PRIVATE_KEY]:
        secret["key_name"] = l0[1]
        if stype == RSA_PUBLIC_KEY:
            secret["key_desc"] = " ".join(l0[2:])
    else:
        print(bcolors.FAIL, "ERROR:", bcolors.ENDC, "Invalid secret type", stype, "in", file_path)
        sys.exit(-1)

    return secret

def read_secrets(module_list):
    """Collect the secret directives of the given modules.

    For each module, every non-hidden regular file below
    ``<module>/secrets/`` (searched recursively) is parsed with read_secret.
    The Consul key of a secret is its file path with the first path
    component removed, e.g. ``app/foo/secrets/bar`` -> ``foo/secrets/bar``.

    Files are visited in sorted order, so checks and interactive prompts
    always appear in the same order whatever the filesystem's listing order.

    Args:
        module_list: module directory paths. Leading and trailing slashes
            are stripped before use, so absolute paths become relative.

    Returns:
        dict mapping Consul key -> secret description (see read_secret).
    """
    secrets = {}
    for mod in module_list:
        pattern = mod.strip('/') + "/secrets/**"
        # glob returns paths in arbitrary order; sort for determinism.
        for file_path in sorted(glob.glob(pattern, recursive=True)):
            if os.path.isfile(file_path):
                key = '/'.join(file_path.split("/")[1:])
                secrets[key] = read_secret(key, file_path)
    return secrets

def get_secrets_services(secrets):
    """Group the LDAP service-account secrets by service name.

    Only SERVICE_DN and SERVICE_PASSWORD secrets are considered, and a
    progress line is printed for each of them. For SERVICE_PASSWORD secrets
    the stored value is read from Consul; the first value found becomes the
    service's reference password.

    Returns:
        dict: service name -> {"dn", "desc", "pass", "dn_at", "pass_at"}.
        "dn" is the expected DN under SERVICE_DN_SUFFIX, "desc" defaults to
        "(not provided)", "pass" is None while no stored password was found,
        and "dn_at"/"pass_at" list the Consul keys holding DN / password.
    """
    services = {}
    relevant_types = [SERVICE_DN, SERVICE_PASSWORD]
    for key, secret in secrets.items():
        stype = secret["type"]
        if stype in relevant_types:
            name = secret["service"]
            print(name, "@", key, bcolors.OKCYAN, "...", bcolors.ENDC)
            entry = services.setdefault(name, {
                "dn": f"cn={name},{SERVICE_DN_SUFFIX}",
                "desc": "(not provided)",
                "pass": None,
                "dn_at": [],
                "pass_at": [],
            })

            if stype == SERVICE_DN:
                entry["dn_at"].append(key)
                entry["desc"] = secret["service_desc"]
            elif stype == SERVICE_PASSWORD:
                entry["pass_at"].append(key)
                _, stored = consul_server.kv.get(key)
                # Keep the first stored password; later keys only get listed.
                if stored is not None and entry["pass"] is None:
                    entry["pass"] = stored["Value"].decode('ascii').strip()

    return services

# Cached admin connection, created lazily by get_ldap_admin_conn().
ldap_admin_conn = None
def get_ldap_admin_conn():
    """Return an LDAP connection bound as an administrator.

    On first use, prompts on the terminal for the admin DN and password and
    binds to LDAP_URL with them. The bound connection is cached in the
    module-level ``ldap_admin_conn`` and reused by later calls.

    Exceptions raised by the bind propagate to the caller. Nothing is cached
    in that case, so a later call prompts for credentials again instead of
    returning an unbound connection.
    """
    global ldap_admin_conn
    if ldap_admin_conn is None:
        conn = ldap.initialize(LDAP_URL)
        ldap_user = input("LDAP admin user (full DN, please!): ")
        ldap_pass = getpass.getpass("LDAP admin password: ")
        conn.simple_bind_s(ldap_user, ldap_pass)
        # Cache only after the bind succeeded.
        ldap_admin_conn = conn
    return ldap_admin_conn

# ---- CHECK COMMAND ----

def check_secrets(module_list):
    """Run the `check` command for the given modules.

    Reads the modules' secret directives, then reports which secrets are
    missing from Consul and whether the LDAP service users are consistent.
    """
    secrets = read_secrets(module_list)
    print("Found", len(secrets), "secrets to check")
    print()

    for run_check in (check_secrets_presence, check_secrets_services):
        run_check(secrets)

def check_secrets_presence(secrets):
    """Print, for every secret key, whether Consul holds a value for it.

    Present keys are marked with a green check mark, missing ones with a
    red "x".
    """
    print("Checking secrets presence...")
    for key in secrets:
        _, data = consul_server.kv.get(key)
        colour, mark = (bcolors.FAIL, "x") if data is None else (bcolors.OKGREEN, "✓")
        print(key, colour, mark, bcolors.ENDC)
    print()

def check_secrets_services(secrets):
    """Check the Consul entries and LDAP logins of the service users.

    For each service referenced by SERVICE_DN/SERVICE_PASSWORD secrets:
    - report every DN key whose stored value differs from the expected DN;
    - report if no password is stored at all;
    - otherwise report every password key whose value differs from the
      reference password, then try an LDAP bind with the DN and password
      and print the outcome.
    Keys that are absent from Consul are skipped here; presence is reported
    by check_secrets_presence.
    """
    print("Checking secrets for LDAP service users...")
    services = get_secrets_services(secrets)

    for svc_name, svc in services.items():
        for dn_key in svc["dn_at"]:
            _, data = consul_server.kv.get(dn_key)
            if data is not None:
                got_val = data["Value"].decode('ascii').strip()
                if got_val != svc["dn"]:
                    print(svc_name, "wrong DN at", dn_key, bcolors.FAIL, "x", bcolors.ENDC)
                    print("got:", got_val, "instead of:", svc["dn"])

        if svc["pass"] is None:
            print(svc_name, bcolors.FAIL, "no password stored", bcolors.ENDC)
            continue

        for pass_key in svc["pass_at"]:
            _, data = consul_server.kv.get(pass_key)
            if data is not None:
                got_val = data["Value"].decode('ascii').strip()
                if got_val != svc["pass"]:
                    # Name the mismatching *password* key (dn_key would be
                    # stale here, or unbound if the service has no DN key).
                    print(svc_name, "wrong pass at", pass_key, bcolors.FAIL, "x", bcolors.ENDC)

        # Verify that the stored credentials actually work against LDAP.
        conn = ldap.initialize(LDAP_URL)
        try:
            conn.simple_bind_s(svc["dn"], svc["pass"])
            print(svc_name, bcolors.OKGREEN, "✓", bcolors.ENDC)
        except Exception as e:
            # Reporting boundary: show whatever went wrong and continue with
            # the next service.
            print(svc_name, bcolors.FAIL, e, bcolors.ENDC)
    print()


# ---- GEN COMMAND ----

def gen_secrets(module_list, regen):
    """Run the `gen` / `regen` command for the given modules.

    Fills in the plain secrets, then the LDAP service-account secrets, and
    finally runs the same checks as the `check` command.

    Args:
        module_list: module directories to read secret directives from.
        regen: when true, secrets already stored in Consul are processed too.
            Users are prompted again, constants and CMD secrets are rewritten
            (CMD_ONCE secrets are kept), and new service passwords are issued.
    """
    secrets = read_secrets(module_list)
    print("Found", len(secrets), "secrets to check and maybe generate")
    print()

    for generate in (gen_secrets_base, gen_secrets_services):
        generate(secrets, regen)
    for run_check in (check_secrets_presence, check_secrets_services):
        run_check(secrets)

def gen_secrets_base(secrets, regen):
    """Fill in USER, USER_LONG, CONST, CONST_LONG, CMD and CMD_ONCE secrets.

    A secret that already has a value in Consul is skipped unless ``regen``
    is true. Even then, a CMD_ONCE secret is only generated when absent.

    - USER / USER_LONG: the value is read interactively from stdin; USER_LONG
      reads lines until one containing a single "."; ^C skips the secret.
    - CONST / CONST_LONG: the constant from the directive file is written.
    - CMD / CMD_ONCE: the command is run with ``sh -c`` and its output is
      written.
    Other secret types are ignored here.
    """
    print("Filling in user secrets and cmd secrets...")

    for key, secret in secrets.items():
        _, data = consul_server.kv.get(key)
        exists = data is not None
        if exists and not regen:
            continue

        stype = secret["type"]
        if stype in (USER, USER_LONG):
            print("----", key, sep="\n")
            print("Description:", secret["desc"])
            print("Enter value for secret, or ^C to skip:")
            if stype == USER_LONG:
                print("THIS IS A LONG VALUE, ENTER SEVERAL LINES AND FINISH WITH A LINE CONTAINING A SINGLE .")
            try:
                if stype == USER:
                    val = input().strip()
                else:
                    # Read stripped lines until the "." terminator line.
                    val = "\n".join(iter(lambda: input().strip(), "."))
                consul_server.kv.put(key, val)
                print(bcolors.OKCYAN, "Value set.", bcolors.ENDC)
            except KeyboardInterrupt:
                print(bcolors.WARNING, "Skipped.", bcolors.ENDC)

        elif stype in (CONST, CONST_LONG):
            print("----", key, sep="\n")
            print("Resetting to constant value.")
            consul_server.kv.put(key, secret["value"])
            print(bcolors.OKCYAN, "Value set.", bcolors.ENDC)

        elif stype == CMD or (stype == CMD_ONCE and not exists):
            print("----", key, sep="\n")
            print("Executing command:", secret["cmd"])
            # The command's raw stdout bytes are stored unchanged, including
            # any trailing newline.
            output = subprocess.check_output(["sh", "-c", secret["cmd"]])
            consul_server.kv.put(key, output)
            print(bcolors.OKCYAN, "Value set.", bcolors.ENDC)

    print()

def gen_secrets_services(secrets, regen):
    """Create or repair LDAP service accounts and store their DN/password in Consul.

    For each service referenced by SERVICE_DN/SERVICE_PASSWORD secrets:
    - write the expected DN to every DN key whose value is missing or wrong;
    - generate a new random password when none (or an empty one) is stored,
      or when ``regen`` is true;
    - try an LDAP bind with the DN and password; if it fails, create the
      account or reset its password through fix_service_user;
    - write the password to every password key whose value is missing or
      different.
    """
    print("Generating LDAP service accounts...")
    services = get_secrets_services(secrets)

    for svc_name, svc in services.items():
        print("----")
        print("Service:", svc_name)
        print("Description:", svc["desc"])

        for dn_key in svc["dn_at"]:
            _, data = consul_server.kv.get(dn_key)
            if data is None or data["Value"].decode('ascii').strip() != svc["dn"]:
                print(bcolors.OKCYAN, "Setting DN", bcolors.ENDC, "at", dn_key)
                consul_server.kv.put(dn_key, svc["dn"])

        # An empty password would lead to an unauthenticated LDAP bind (or an
        # empty account password), so treat it like a missing one.
        if not svc["pass"] or regen:
            print(bcolors.OKCYAN, "Generating new password", bcolors.ENDC)
            # 12 random bytes -> 16 URL-safe base64 characters, no padding.
            svc["pass"] = base64.urlsafe_b64encode(token_bytes(12)).decode('ascii')

        l = ldap.initialize(LDAP_URL)
        try:
            l.simple_bind_s(svc["dn"], svc["pass"])
        except ldap.LDAPError:
            # Bind failed (missing account, wrong password, ...): create or
            # repair the account. Only LDAP errors are caught, so ^C and
            # programming errors are not silently turned into a "fix".
            fix_service_user(svc)

        for pass_key in svc["pass_at"]:
            _, data = consul_server.kv.get(pass_key)
            if data is None or data["Value"].decode('ascii').strip() != svc["pass"]:
                print(bcolors.OKCYAN, "Setting password", bcolors.ENDC, "at", pass_key)
                consul_server.kv.put(pass_key, svc["pass"])

    print()

def fix_service_user(svc):
    """Create the LDAP entry of a service user, or reset its password.

    Uses the admin connection from get_ldap_admin_conn, which prompts for
    admin credentials on first use. The password is stored hashed with
    passlib's ldap_salted_sha1.

    Args:
        svc: service description as built by get_secrets_services; reads
            "dn", "desc" and "pass" (which must not be None).
    """
    print("Fixing service user", svc["dn"], "...")
    l = get_ldap_admin_conn()
    try:
        res = l.search_s(svc["dn"], ldap.SCOPE_BASE, "objectclass=*")
    except ldap.NO_SUCH_OBJECT:
        # A base-scope search on a DN that does not exist fails with
        # noSuchObject instead of returning an empty list; treat it as
        # "entry absent" so that the account gets created below.
        res = []
    pass_crypt = ldap_salted_sha1.hash(svc["pass"])
    if res is None or len(res) == 0:
        print(bcolors.OKCYAN, "Creating entity...", bcolors.ENDC)
        l.add_s(svc["dn"],
                [
                    ("objectclass",     [b"person", b"top"]),
                    # LDAP string values are UTF-8; descriptions may contain
                    # non-ASCII characters, which .encode('ascii') rejected.
                    ("displayname",     [svc["desc"].encode('utf-8')]),
                    ("userpassword",    [pass_crypt.encode('ascii')]),
                ])
    else:
        print(bcolors.OKCYAN, "Resetting entity password", bcolors.ENDC)
        l.modify_s(svc["dn"],
                [
                    (ldap.MOD_REPLACE, "userpassword", [pass_crypt.encode('ascii')])
                ])

# ---- MAIN ----

if __name__ == "__main__":
    # Find the first command word after the program name; the arguments that
    # follow it are the module directories. The loop's `else` runs only if no
    # command word was found, so the usage text is no longer printed for every
    # non-command argument (such as the program name itself).
    for i, val in enumerate(sys.argv[1:], start=1):
        if val == "check":
            check_secrets(sys.argv[i+1:])
            break
        elif val == "gen":
            gen_secrets(sys.argv[i+1:], False)
            break
        elif val == "regen":
            gen_secrets(sys.argv[i+1:], True)
            break
    else:
        print("Usage:")
        print("    secretmgr.py [check|gen|regen] <module name>...")
        # Invalid invocation: signal failure to the calling shell.
        sys.exit(1)