aboutsummaryrefslogblamecommitdiff
path: root/util.go
blob: cbb6b0f2b038c66efa1d7ee4111074bc085481fe (plain) (tree)
1
2
3
4
5
6
7
8
9




                       
                 
              
 
                                     
                                                
                                        

 






                         














                                                               
                                                     







                                                         
                                             





                                     
                                                                                             

 









































                                                                                                   


















                                                                            

                                                     


                                


                                                         
                                                    








                                            
































                                                                                                                          







                                                       








                                                   
package main

import (
	"encoding/json"
	"fmt"
	"strings"
	"time"

	uuid "github.com/google/uuid"
	consul "github.com/hashicorp/consul/api"
	log "github.com/sirupsen/logrus"
)

// DNs ----

// dnComponent is one relative distinguished name (RDN) of a DN, e.g. the
// "cn=alice" part of "cn=alice,dc=example": Type holds the part before the
// "=" and Value the part after it.
type dnComponent struct {
	Type, Value string
}

// dnToConsul maps a DN such as "cn=a,dc=b" to a Consul key path such as
// "dc=b/cn=a". The comma-separated RDNs are emitted in reverse order (root
// first) and joined with "/". RDNs are not trimmed or normalized here.
//
// A DN that already contains "/" is rejected, because "/" is used as the
// path separator in the resulting key.
func dnToConsul(dn string) (string, error) {
	if strings.Contains(dn, "/") {
		return "", fmt.Errorf("DN %s contains a /", dn)
	}

	parts := strings.Split(dn, ",")
	rootFirst := make([]string, 0, len(parts))
	for i := len(parts) - 1; i >= 0; i-- {
		rootFirst = append(rootFirst, parts[i])
	}
	return strings.Join(rootFirst, "/"), nil
}

// consulToDN is the inverse of dnToConsul for attribute keys. It turns a key
// such as "dc=b/cn=a/attribute=mail" into the DN "cn=a,dc=b" and the
// attribute name "mail".
//
// Empty path segments are skipped. Scanning stops at the first segment that
// has exactly one "=" and the form "attribute=<name>"; any segments after it
// are ignored. If no such segment exists, an error is returned.
func consulToDN(key string) (string, string, error) {
	var rootFirst []string
	for _, segment := range strings.Split(key, "/") {
		if segment == "" {
			continue
		}
		if kv := strings.Split(segment, "="); len(kv) == 2 && kv[0] == "attribute" {
			// The key stores RDNs root-first; a DN lists them leaf-first.
			leafFirst := make([]string, len(rootFirst))
			for i, rdn := range rootFirst {
				leafFirst[len(rootFirst)-1-i] = rdn
			}
			return strings.Join(leafFirst, ","), kv[1], nil
		}
		rootFirst = append(rootFirst, segment)
	}
	return "", "", fmt.Errorf("Consul key %s does not end with attribute=something", key)
}

// parseDN splits a DN such as "CN=Alice, DC=Example" into its RDN components.
// Each comma-separated part must contain exactly one "="; otherwise an error
// is returned. An empty DN is therefore rejected. The type and value of each
// component are trimmed of surrounding whitespace and lower-cased.
//
// The DN is split literally on "," and "=", so escaped separators (such as
// "\,") are not supported.
func parseDN(dn string) ([]dnComponent, error) {
	normalize := func(s string) string {
		return strings.ToLower(strings.TrimSpace(s))
	}

	rdns := strings.Split(dn, ",")
	components := make([]dnComponent, 0, len(rdns))
	for _, rdn := range rdns {
		kv := strings.Split(rdn, "=")
		if len(kv) != 2 {
			return nil, fmt.Errorf("Wrong DN component: %s (expected type=value)", rdn)
		}
		components = append(components, dnComponent{
			Type:  normalize(kv[0]),
			Value: normalize(kv[1]),
		})
	}
	return components, nil
}

// unparseDN serializes components back into a DN string of the form
// "type1=value1,type2=value2". No escaping is performed. An empty or nil
// slice yields "".
func unparseDN(path []dnComponent) string {
	// strings.Builder avoids re-copying the whole prefix on every component,
	// which the previous `ret = ret + ...` loop did (quadratic in DN length).
	var b strings.Builder
	for i, c := range path {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(c.Type)
		b.WriteByte('=')
		b.WriteString(c.Value)
	}
	return b.String()
}

// canonicalDN normalizes dn by parsing it with parseDN and re-serializing the
// components with unparseDN. Because parseDN trims and lower-cases each type
// and value, spellings such as "CN=Foo, DC=Bar" and "cn=foo,dc=bar" produce
// the same result. Parse errors are returned unchanged.
func canonicalDN(dn string) (string, error) {
	components, err := parseDN(dn)
	if err == nil {
		return unparseDN(components), nil
	}
	return "", err
}

// Values

// Entry holds the attributes of a single directory entry: each attribute name
// maps to its list of string values (as built by parseConsulResult).
type Entry map[string][]string

// parseValue decodes a stored attribute value. The value is accepted either as
// a JSON array of strings, returned as-is, or as a single JSON string, which
// is wrapped in a one-element slice. Anything else is an error.
//
// A JSON `[]` yields an empty, non-nil slice. A JSON `null` decodes into the
// array form and yields a nil slice with no error.
func parseValue(value []byte) ([]string, error) {
	list := []string{}
	if json.Unmarshal(value, &list) == nil {
		return list, nil
	}

	var single string
	if json.Unmarshal(value, &single) == nil {
		return []string{single}, nil
	}

	return nil, fmt.Errorf("Not a string or list of strings: %s", value)
}

// parseConsulResult groups a flat list of Consul KV pairs into entries keyed
// by DN. Each key is decoded with consulToDN into a DN and an attribute name.
// Each value is decoded with parseValue into that attribute's value list.
//
// Keys that consulToDN cannot decode are skipped. A value that parseValue
// rejects aborts the whole call. The returned error names the offending Consul
// key and wraps the parseValue error, so errors.Is/As still see it.
func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) {
	aggregator := map[string]Entry{}

	for _, pair := range data {
		dn, attr, err := consulToDN(pair.Key)
		if err != nil {
			// Not an attribute key; deliberately ignored.
			continue
		}

		value, err := parseValue(pair.Value)
		if err != nil {
			return nil, fmt.Errorf("parsing value of Consul key %s: %w", pair.Key, err)
		}

		entry, ok := aggregator[dn]
		if !ok {
			entry = Entry{}
			aggregator[dn] = entry
		}
		entry[attr] = value
	}

	return aggregator, nil
}

// checkRestrictedAttr returns an error if attr names an attribute that clients
// may not set directly. The comparison is case-insensitive. memberOf gets a
// dedicated message pointing at group membership. ATTR_ENTRYUUID,
// ATTR_CREATORSNAME, ATTR_CREATETIMESTAMP, ATTR_MODIFIERSNAME and
// ATTR_MODIFYTIMESTAMP are reported as system-managed. Any other name yields
// nil.
func checkRestrictedAttr(attr string) error {
	// Checked first so memberOf always gets the more specific message.
	if strings.EqualFold(attr, ATTR_MEMBEROF) {
		return fmt.Errorf("memberOf cannot be defined directly, membership must be specified in the group itself")
	}

	// memberOf is intentionally absent here: it is handled above, so listing
	// it again would be dead code.
	systemAttrs := []string{
		ATTR_ENTRYUUID,
		ATTR_CREATORSNAME,
		ATTR_CREATETIMESTAMP,
		ATTR_MODIFIERSNAME,
		ATTR_MODIFYTIMESTAMP,
	}
	for _, s := range systemAttrs {
		if strings.EqualFold(attr, s) {
			return fmt.Errorf("Attribute %s is restricted and may only be set by the system", s)
		}
	}
	return nil
}

// genTimestamp returns the current time formatted as YYYYMMDDHHMMSS followed
// by a literal "Z" (the UTC GeneralizedTime form), e.g. "20240131235959Z".
func genTimestamp() string {
	// The trailing "Z" in the layout is a literal UTC designator, not a zone
	// verb, so the time must be converted to UTC first. Formatting local time
	// would mislabel it as UTC on any host whose zone is not UTC.
	return time.Now().UTC().Format("20060102150405Z")
}

// genUuid returns a new random UUID (uuid.NewRandom) in its string form.
// If the random source fails, it panics via log.Panicf.
func genUuid() string {
	// Named id rather than uuid so the local does not shadow the imported
	// uuid package for the rest of the function.
	id, err := uuid.NewRandom()
	if err != nil {
		log.Panicf("UUID generation error: %s", err)
	}
	return id.String()
}

// valueMatch reports whether val1 and val2 should be treated as the same value
// of attribute attr. Values of ATTR_USERPASSWORD (matched case-insensitively
// on the attribute name) must be byte-for-byte equal, presumably because
// password values or hashes are case-sensitive. All other values are compared
// case-insensitively.
func valueMatch(attr, val1, val2 string) bool {
	if !strings.EqualFold(attr, ATTR_USERPASSWORD) {
		return strings.EqualFold(val1, val2)
	}
	return val1 == val2
}

// listContains reports whether key appears in list, using exact
// (case-sensitive) string equality. A nil or empty list never contains
// anything.
func listContains(list []string, key string) bool {
	for i := range list {
		if list[i] == key {
			return true
		}
	}
	return false
}