aboutsummaryrefslogtreecommitdiff
path: root/ldapserver/packet.go
blob: 24d01ed93b37c14b33fa887f2f45cc9d961304ce (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package ldapserver

import (
	"bufio"
	"errors"
	"fmt"
	"io"

	ldap "github.com/vjeantet/goldap/message"
)

// messagePacket holds the raw, still-encoded bytes of a single LDAP message
// as they were read from the connection.
type messagePacket struct {
	// bytes is the undecoded message, including its leading tag and length
	// octets.
	bytes []byte
}

// readMessagePacket pulls the raw bytes of the next LDAP message from br and
// wraps them in a messagePacket. When reading fails, an empty (non-nil)
// messagePacket is returned together with the error.
func readMessagePacket(br *bufio.Reader) (*messagePacket, error) {
	raw, err := readLdapMessageBytes(br)
	if err != nil {
		return &messagePacket{}, err
	}
	return &messagePacket{bytes: *raw}, nil
}

// readMessage decodes the packet's raw bytes into an LDAPMessage. A panic
// raised while decoding is recovered and reported as an error that includes a
// hex dump of the offending packet.
func (msg *messagePacket) readMessage() (m ldap.LDAPMessage, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = fmt.Errorf("invalid packet received hex=%x, %#v", msg.bytes, r)
	}()

	m, err = decodeMessage(msg.bytes)
	return m, err
}

// decodeMessage parses bytes, starting at offset 0, into an LDAPMessage.
// Any panic raised by the decoder is recovered and returned as an error
// carrying the panic value's text.
func decodeMessage(bytes []byte) (ret ldap.LDAPMessage, err error) {
	defer func() {
		if cause := recover(); cause != nil {
			err = errors.New(fmt.Sprintf("%s", cause))
		}
	}()

	// Decoding always begins at the first byte of the packet.
	const startOffset = 0
	reader := ldap.NewBytes(startOffset, bytes)
	return ldap.ReadLDAPMessage(reader)
}

// BELOW SHOULD BE IN ROOX PACKAGE

// readLdapMessageBytes reads one complete LDAP message from br: first the tag
// and length header, then as many body bytes as the header announces.
//
// On success it returns a pointer to the raw message bytes (header included).
// On any read or parse failure it returns a nil pointer and the error, so a
// truncated message is never handed to the caller as if it were complete.
func readLdapMessageBytes(br *bufio.Reader) (ret *[]byte, err error) {
	var bytes []byte
	var tagAndLength ldap.TagAndLength
	tagAndLength, err = readTagAndLength(br, &bytes)
	if err != nil {
		return nil, err
	}
	// The body read can fail (EOF, timeout, ...); that error must reach the
	// caller instead of being dropped.
	if _, err = readBytes(br, &bytes, tagAndLength.Length); err != nil {
		return nil, err
	}
	return &bytes, nil
}

// readTagAndLength reads an ASN.1 identifier octet and the length octets that
// follow it from conn. Every byte consumed is appended to *bytes so that the
// caller ends up with the complete raw message in that buffer.
//
// The identifier octet must be 0x30, the SEQUENCE tag that opens an LDAP
// message; any other value is reported as a ldap.StructuralError rather than
// a panic, because the byte comes straight from the network.
//
// The length may be in short form (a single octet below 0x80) or long form (a
// count octet followed by that many big-endian length octets). An indefinite
// length (count of zero) yields a ldap.SyntaxError. A long-form length is
// rejected with a ldap.StructuralError once the accumulated value has reached
// 1<<23 and another octet would still have to be shifted in.
//
// Errors from the underlying reads are returned unchanged.
func readTagAndLength(conn *bufio.Reader, bytes *[]byte) (ret ldap.TagAndLength, err error) {
	var b byte
	b, err = readBytes(conn, bytes, 1)
	if err != nil {
		return ret, err
	}
	// We are expecting the LDAP sequence tag 0x30 as first byte.
	if b != 0x30 {
		err = ldap.StructuralError{fmt.Sprintf("Expecting 0x30 as first byte, but got %#x instead", b)}
		return ret, err
	}
	ret.Class = int(b >> 6)
	ret.IsCompound = b&0x20 == 0x20
	ret.Tag = int(b & 0x1f)

	b, err = readBytes(conn, bytes, 1)
	if err != nil {
		return ret, err
	}
	if b&0x80 == 0 {
		// Short form: the length is encoded in the bottom 7 bits.
		ret.Length = int(b & 0x7f)
		return ret, nil
	}

	// Long form: the bottom 7 bits give the number of length bytes to follow.
	numBytes := int(b & 0x7f)
	if numBytes == 0 {
		err = ldap.SyntaxError{"indefinite length found (not DER)"}
		return ret, err
	}
	ret.Length = 0
	for i := 0; i < numBytes; i++ {
		b, err = readBytes(conn, bytes, 1)
		if err != nil {
			return ret, err
		}
		if ret.Length >= 1<<23 {
			// We can't shift ret.Length up without overflowing.
			err = ldap.StructuralError{"length too large"}
			return ret, err
		}
		ret.Length <<= 8
		ret.Length |= int(b)
		// Leading zero octets are tolerated on purpose: some client
		// libraries (go-ldap among them) emit non-minimal lengths for
		// values above 127, so DER's minimal-length rule is not enforced.
	}

	return ret, nil
}

// readBytes reads exactly length bytes from conn, appends them to *bytes and
// returns the last byte of *bytes.
//
// A single bufio.Reader.Read may return fewer bytes than requested, so
// io.ReadFull is used to keep reading until the buffer is full. If the full
// amount cannot be read, *bytes is left untouched and the error from the
// reader is returned as-is (io.EOF when nothing was read,
// io.ErrUnexpectedEOF when the stream ended part-way), so callers can still
// inspect the concrete error type.
//
// A length of zero reads nothing; b is then the current last byte of *bytes,
// or 0 when *bytes is empty. A negative length is rejected with an error.
func readBytes(conn *bufio.Reader, bytes *[]byte, length int) (b byte, err error) {
	if length < 0 {
		return 0, fmt.Errorf("invalid read length %d", length)
	}
	if length > 0 {
		newbytes := make([]byte, length)
		if _, err = io.ReadFull(conn, newbytes); err != nil {
			return 0, err
		}
		*bytes = append(*bytes, newbytes...)
	}
	if len(*bytes) == 0 {
		// Nothing accumulated yet: there is no "last byte" to report.
		return 0, nil
	}
	return (*bytes)[len(*bytes)-1], nil
}